diff --git a/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb b/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb index e23180be4b2b3ebb0fa7b86798291f716322a6a5..ceb6052a49e387b6e1af94e850c3bffb5c13fcd0 100644 --- a/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb +++ b/tania_scripts/.ipynb_checkpoints/tania-some-other-metrics-checkpoint.ipynb @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "2a882cc9-8f9d-4457-becb-d2e26ab3f14f", "metadata": {}, "outputs": [ @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "8897dcc3-4218-4ee5-9984-17b9a6d8dce2", "metadata": {}, "outputs": [], @@ -163,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "1363f307-fa4b-43ba-93d5-2d1c11ceb9e4", "metadata": {}, "outputs": [ @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "1362e192-514a-4a77-a8cb-5c012026e2bb", "metadata": {}, "outputs": [], @@ -238,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "544ff6aa-4104-4580-a01f-97429ffcc228", "metadata": {}, "outputs": [ @@ -328,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "b9052dc2-ce45-4af4-a0a0-46c60a13da12", "metadata": {}, "outputs": [], @@ -388,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "1e9dd0fb-db6a-47d1-8bfb-1015845f6d3e", "metadata": {}, "outputs": [ @@ -398,7 +398,7 @@ "{'Flesch-Douma': 88.68, 'LIX': 11.55, 'Kandel-Moles': 5.86}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -421,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "24bc84a5-b2df-4194-838a-8f24302599bd", "metadata": {}, "outputs": [], @@ -467,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "0cdb972f-31b6-4e7e-82a8-371eda344f2c", "metadata": {}, "outputs": [ @@ -477,7 +477,7 @@ "{'Average Word Length': 3.79, 'Average Sentence Length': 7.0}" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -521,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "56af520c-d56b-404a-aebf-ad7c2a9ca503", "metadata": {}, "outputs": [], @@ -567,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "f7c8b125-4651-4b21-bcc4-93ef78a4239b", "metadata": {}, "outputs": [ @@ -603,7 +603,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 15, "id": "daa17c33-adca-4695-90eb-741579382939", "metadata": {}, "outputs": [], @@ -622,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "id": "80d8fa08-6b7d-4ab7-85cd-987823639277", "metadata": {}, "outputs": [ @@ -665,7 +665,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "id": "3f9c7dc7-6820-4013-a85c-2af4f846d4f5", "metadata": {}, "outputs": [ @@ -693,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "id": "65e1a630-c46e-4b18-9831-b97864de53ee", "metadata": {}, "outputs": [], @@ -713,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "id": "1612e911-12a8-47c9-b811-b2d6885c3647", "metadata": {}, "outputs": [ @@ -742,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "925a3a75-aaaa-4851-b77b-b42cb1e21e11", "metadata": {}, "outputs": [], @@ -757,7 +757,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "id": "6fa60897-ad26-43b4-b8de-861290ca6bd3", "metadata": {}, "outputs": [ @@ -783,25 +783,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 23, "id": "f3678462-e572-4ce5-8d3d-a5389b2356c8", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Defaulting to user installation because normal site-packages is not writeable\n", - "Collecting scipy\n", - " Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", - "Requirement already satisfied: numpy<2.5,>=1.23.5 in /public/conda/Miniconda/envs/pytorch-2.6/lib/python3.11/site-packages (from scipy) (2.2.4)\n", - "Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", - "\u001b[?25hInstalling collected packages: scipy\n", - "Successfully installed scipy-1.15.3\n" - ] - } - ], + "outputs": [], "source": [ "#!pip3 install seaborn\n", "#!pip3 install scipy" @@ -809,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 24, "id": "b621b2a8-488f-44db-b085-fe156f453943", "metadata": {}, "outputs": [ @@ -830,7 +815,7 @@ "import matplotlib.pyplot as plt\n", "from scipy.stats import spearmanr\n", "\n", - "# Sample data (replace with your real values)\n", + "# Sample data (to be replaces with real values)\n", "data = {\n", " \"perplexity\": [32.5, 45.2, 28.1, 39.0, 50.3],\n", " \"avg_word_length\": [4.1, 4.3, 4.0, 4.2, 4.5],\n", @@ -853,10 +838,123 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "45ee04fc-acab-4bba-ba06-e4cf4bca9fe5", + "metadata": {}, + "source": [ + "## Tree depth" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "79f99787-c220-4f1d-93a9-59230363ec3f", + "metadata": {}, + "outputs": [], + "source": [ + "def parse_sentence_block(text):\n", + " lines = text.strip().split('\\n')\n", + " result = []\n", + " tokenlist = []\n", + " for line in lines:\n", + " # Split the line by tab and strip whitespace\n", + " parts = tuple(line.strip().split('\\t'))\n", + " # Only include lines that have exactly 4 parts\n", + " if len(parts) == 4:\n", + " parentidx = int(parts[3])\n", + " if '@@' in parts[2]:\n", + " nonterm1 = parts[2].split('@@')[0]\n", + " nonterm2 = parts[2].split('@@')[1]\n", + " else:\n", + " nonterm1 = parts[2]\n", + " nonterm2 = '<nul>'\n", + " postag = parts[1]\n", + " token = parts[0]\n", + " result.append((parentidx, nonterm1, nonterm2, postag))\n", + " tokenlist.append(token)\n", + " return result, tokenlist\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "f567efb0-8b0b-4782-9345-052cf1785776", + "metadata": {}, + "outputs": [], + "source": [ + "example_sentence = \"\"\"\n", + "<s>\t<s>\t<s>\t1\n", + "--\tponct\t<nul>@@<nul>\t1\n", + "Eh\tnpp\t<nul>@@<nul>\t1\n", + "bien?\tadv\tAP@@<nul>\t1\n", + "fit\tv\tVN@@<nul>\t2\n", + "-il\tcls-suj\tVN@@VPinf-OBJ\t3\n", + ".\tponct\t<nul>@@<nul>\t4\n", + "</s>\t</s>\t</s>\t4\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8d4ecba9-89b8-4000-a061-aa16aa68a404", + "metadata": {}, + "outputs": [], + "source": [ + "from transform import *\n", + "\n", + "def visualize_const_prediction(example_sent):\n", + " parsed, tokenlist = parse_sentence_block(example_sent)\n", + " tree = AttachJuxtaposeTree.totree(tokenlist, 'SENT')\n", + " AttachJuxtaposeTree.action2tree(tree, parsed).pretty_print()\n", + " nltk_tree = AttachJuxtaposeTree.action2tree(tree, parsed)\n", + " #print(\"NLTK TREE\", nltk_tree)\n", + " depth = nltk_tree.height() - 1 # NLTK includes the leaf level as height 1, so subtract 1 for tree depth \n", + " print(\"Tree depth:\", depth)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "bfd3abf3-b83a-4817-85ad-654daf72be88", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " SENT \n", + " | \n", + " <s> \n", + " | \n", + " <s> \n", + " ______________|__________ \n", + " | | | AP \n", + " | | | __________|________ \n", + " | | | | VPinf-OBJ \n", + " | | | | ______________|_______ \n", + " | | | | | VN \n", + " | | | | | ________________|____ \n", + " | | | | VN | | </s>\n", + " | | | | | | | | \n", + " | ponct npp adv v cls-suj ponct </s>\n", + " | | | | | | | | \n", + "<s> -- Eh bien? fit -il . </s>\n", + "\n", + "Tree depth: 8\n" + ] + } + ], + "source": [ + "visualize_const_prediction(example_sentence)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "3a6e3b53-7104-45ef-a4b5-e831bdd6ca6f", + "id": "bc51ab44-6885-45cc-bad2-6a43a7791fdb", "metadata": {}, "outputs": [], "source": [] diff --git a/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py b/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6 --- /dev/null +++ b/tania_scripts/.ipynb_checkpoints/transform-checkpoint.py @@ -0,0 +1,459 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union + +import nltk +import torch + +from supar.models.const.crf.transform import Tree +from supar.utils.common import NUL +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class AttachJuxtaposeTree(Tree): + r""" + :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class, + supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + NODE: + The target node on each rightmost chain. + PARENT: + The label of the parent node of each terminal. + NEW: + The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children. + ``NUL`` represents the `Attach` action. + """ + + fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + NODE: Optional[Union[Field, Iterable[Field]]] = None, + PARENT: Optional[Union[Field, Iterable[Field]]] = None, + NEW: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.NODE = NODE + self.PARENT = PARENT + self.NEW = NEW + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.NODE, self.PARENT, self.NEW + + @classmethod + def tree2action(cls, tree: nltk.Tree): + r""" + Converts a constituency tree into AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + A constituency tree in :class:`nltk.tree.Tree` format. + + Returns: + A sequence of AttachJuxtapose actions. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ Arthur)) + (VP + (_ is) + (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + >>> AttachJuxtaposeTree.tree2action(tree) + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')] + """ + + def isroot(node): + return node == tree[0] + + def isterminal(node): + return len(node) == 1 and not isinstance(node[0], nltk.Tree) + + def last_leaf(node): + pos = () + while True: + pos += (len(node) - 1,) + node = node[-1] + if isterminal(node): + return node, pos + + def parent(position): + return tree[position[:-1]] + + def grand(position): + return tree[position[:-2]] + + def detach(tree): + last, last_pos = last_leaf(tree) + siblings = parent(last_pos)[:-1] + + if len(siblings) > 0: + last_subtree = last + last_subtree_siblings = siblings + parent_label = NUL + else: + last_subtree, last_pos = parent(last_pos), last_pos[:-1] + last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1] + parent_label = last_subtree.label() + + target_pos, new_label, last_tree = 0, NUL, tree + if isroot(last_subtree): + last_tree = None + + elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]): + new_label = parent(last_pos).label() + new_label = new_label + target = last_subtree_siblings[0] + last_grand = grand(last_pos) + if last_grand is None: + last_tree = targetistermina + else: + last_grand[-1] = target + target_pos = len(last_pos) - 2 + else: + target = parent(last_pos) + target.pop() + target_pos = len(last_pos) - 2 + action = target_pos, parent_label, new_label + return action, last_tree + if tree is None: + return [] + action, last_tree = detach(tree) + return cls.tree2action(last_tree) + [action] + + @classmethod + def action2tree( + cls, + tree: nltk.Tree, + actions: List[Tuple[int, str, str]], + join: str = '::', + ) -> nltk.Tree: + r""" + Recovers a constituency tree from a sequence of AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + actions (List[Tuple[int, str, str]]): + A sequence of AttachJuxtapose actions. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.action2tree(tree, + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')]).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + """ + + def target(node, depth): + node_pos = () + for _ in range(depth): + node_pos += (len(node) - 1,) + node = node[-1] + return node, node_pos + + def parent(tree, position): + return tree[position[:-1]] + + def execute(tree: nltk.Tree, terminal: Tuple(str, str), action: Tuple[int, str, str]) -> nltk.Tree: + target_pos, parent_label, new_label, post = action + #print(target_pos, parent_label, new_label) + new_leaf = nltk.Tree(post, [terminal[0]]) + + # create the subtree to be inserted + new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf]) + # find the target position at which to insert the new subtree + target_node = tree + if target_node is not None: + target_node, target_pos = target(target_node, target_pos) + + # Attach + if new_label == NUL: + # attach the first token + if target_node is None: + return new_subtree + target_node.append(new_subtree) + # Juxtapose + else: + new_subtree = nltk.Tree(new_label, [target_node, new_subtree]) + if len(target_pos) > 0: + parent_node = parent(tree, target_pos) + parent_node[-1] = new_subtree + else: + tree = new_subtree + return tree + + tree, root, terminals = None, tree.label(), tree.pos() + for terminal, action in zip(terminals, actions): + tree = execute(tree, terminal, action) + # recover unary chains + nodes = [tree] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + nodes.extend(node) + if join in node.label(): + labels = node.label().split(join) + node.set_label(labels[0]) + subtree = nltk.Tree(labels[-1], node) + for label in reversed(labels[1:-1]): + subtree = nltk.Tree(label, [subtree]) + node[:] = [subtree] + return nltk.Tree(root, [tree]) + + @classmethod + def action2span( + cls, + action: torch.Tensor, + spans: torch.Tensor = None, + nul_index: int = -1, + mask: torch.BoolTensor = None + ) -> torch.Tensor: + r""" + Converts a batch of the tensorized action at a given step into spans. + + Args: + action (~torch.Tensor): ``[3, batch_size]``. + A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels. + spans (~torch.Tensor): + Spans generated at previous steps, ``None`` at the first step. Default: ``None``. + nul_index (int): + The index for the obj:`NUL` token, representing the Attach action. Default: -1. + mask (~torch.BoolTensor): ``[batch_size]``. + The mask for covering the unpadded tokens. + + Returns: + A tensor representing a batch of spans for the given step. + + Examples: + >>> from collections import Counter + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab + >>> from supar.utils.common import NUL + >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL), + (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL), + (0, NUL, NUL)]) + >>> vocab = Vocab(Counter(sorted(set([*parents, *news])))) + >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1) + >>> spans = None + >>> for action in actions.unbind(-1): + ... spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL]) + ... + >>> spans + tensor([[[-1, 1, -1, -1, -1, -1, -1, 3], + [-1, -1, -1, -1, -1, -1, 4, -1], + [-1, -1, -1, 1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, 2, -1], + [-1, -1, -1, -1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1]]]) + >>> sequence = torch.where(spans.ge(0)) + >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]])) + >>> sequence + [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')] + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + + """ + + # [batch_size] + target, parent, new = action + if spans is None: + spans = action.new_full((action.shape[1], 2, 2), -1) + spans[:, 0, 1] = parent + return spans + if mask is None: + mask = torch.ones_like(target, dtype=bool) + juxtapose_mask = new.ne(nul_index) & mask + # ancestor nodes are those on the rightmost chain and higher than the target node + # [batch_size, seq_len] + rightmost_mask = spans[..., -1].ge(0) + ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1 + # should not include the target node for the Juxtapose action + ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1)) + target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1] + # the right boundaries of ancestor nodes should be aligned with the new generated terminals + spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1) + spans[..., -2].masked_fill_(ancestor_mask, -1) + spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask] + spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask] + # [batch_size, seq_len+1, seq_len+1] + spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1) + return spans + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[AttachJuxtaposeTreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`AttachJuxtaposeTreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + except IndexError: + tree = nltk.Tree.fromstring('(S ' + s + ')') + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + else: + yield sentence + index += 1 + self.root = tree.label() + + +class AttachJuxtaposeTreeSentence(Sentence): + r""" + Args: + transform (AttachJuxtaposeTree): + A :class:`AttachJuxtaposeTree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: AttachJuxtaposeTree, + tree: nltk.Tree, + index: Optional[int] = None + ) -> AttachJuxtaposeTreeSentence: + super().__init__(transform, index) + + words, tags = zip(*tree.pos()) + nodes, parents, news = None, None, None + if transform.training: + oracle_tree = tree.copy(True) + # the root node must have a unary chain + if len(oracle_tree) > 1: + oracle_tree[:] = [nltk.Tree('*', oracle_tree)] + oracle_tree.collapse_unary(joinChar='::') + if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree): + oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]]) + nodes, parents, news = zip(*transform.tree2action(oracle_tree)) + tags = [x.split("##")[0] for x in tags] + self.values = [words, tags, tree, nodes, parents, news] + + def __repr__(self): + return self.values[-4].pformat(1000000) + + def pretty_print(self): + self.values[-4].pretty_print() diff --git a/tania_scripts/__pycache__/transform.cpython-311.pyc b/tania_scripts/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa71196328a1158b12ae2fffe91d227ae067efc Binary files /dev/null and b/tania_scripts/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..146f3df423731c6ccad8c46a95acf8c26649da6e --- /dev/null +++ b/tania_scripts/supar/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +from .models import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, + BiaffineDependencyParser, + BiaffineSemanticDependencyParser, CRF2oDependencyParser, + CRFConstituencyParser, CRFDependencyParser, SLDependencyParser, ArcEagerDependencyParser, + TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser, + VIDependencyParser, VISemanticDependencyParser) +from .parser import Parser +from .structs import (BiLexicalizedConstituencyCRF, ConstituencyCRF, + ConstituencyLBP, ConstituencyMFVI, Dependency2oCRF, + DependencyCRF, DependencyLBP, DependencyMFVI, + LinearChainCRF, MatrixTree, SemanticDependencyLBP, + SemanticDependencyMFVI, SemiMarkovCRF) + + + +__all__ = ['Parser', + + # Dependency Parsing + 'BiaffineDependencyParser', + 'CRFDependencyParser', + 'CRF2oDependencyParser', + 'VIDependencyParser', + 'SLDependencyParser', + 'ArcEagerDependencyParser', + + # Constituent Parsing + 'AttachJuxtaposeConstituencyParser', + 'CRFConstituencyParser', + 'TetraTaggingConstituencyParser', + 'VIConstituencyParser', + 'SLConstituentParser', + + # Semantic Parsing + 'BiaffineSemanticDependencyParser', + 'VISemanticDependencyParser', + 'LinearChainCRF', + 'SemiMarkovCRF', + + # transforms + 'MatrixTree', + 'DependencyCRF', + 'Dependency2oCRF', + 'ConstituencyCRF', + 'BiLexicalizedConstituencyCRF', + 'DependencyLBP', + 'DependencyMFVI', + 'ConstituencyLBP', + 'ConstituencyMFVI', + 'SemanticDependencyLBP', + 'SemanticDependencyMFVI'] + +__version__ = '1.1.4' + +PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser, + CRFDependencyParser, + CRF2oDependencyParser, + VIDependencyParser, + SLDependencyParser, + ArcEagerDependencyParser, + AttachJuxtaposeConstituencyParser, + CRFConstituencyParser, + TetraTaggingConstituencyParser, + VIConstituencyParser, + SLConstituentParser, + BiaffineSemanticDependencyParser, + VISemanticDependencyParser]} + +SRC = {'github': 'https://github.com/yzhangcs/parser/releases/download', + 'hlt': 'http://hlt.suda.edu.cn/~yzhang/supar'} +NAME = { + 'biaffine-dep-en': 'ptb.biaffine.dep.lstm.char', + 'biaffine-dep-zh': 'ctb7.biaffine.dep.lstm.char', + 'crf2o-dep-en': 'ptb.crf2o.dep.lstm.char', + 'crf2o-dep-zh': 'ctb7.crf2o.dep.lstm.char', + 'biaffine-dep-roberta-en': 'ptb.biaffine.dep.roberta', + 'biaffine-dep-electra-zh': 'ctb7.biaffine.dep.electra', + 'biaffine-dep-xlmr': 'ud.biaffine.dep.xlmr', + 'crf-con-en': 'ptb.crf.con.lstm.char', + 'crf-con-zh': 'ctb7.crf.con.lstm.char', + 'crf-con-roberta-en': 'ptb.crf.con.roberta', + 'crf-con-electra-zh': 'ctb7.crf.con.electra', + 'crf-con-xlmr': 'spmrl.crf.con.xlmr', + 'biaffine-sdp-en': 'dm.biaffine.sdp.lstm.tag-char-lemma', + 'biaffine-sdp-zh': 'semeval16.biaffine.sdp.lstm.tag-char-lemma', + 'vi-sdp-en': 'dm.vi.sdp.lstm.tag-char-lemma', + 'vi-sdp-zh': 'semeval16.vi.sdp.lstm.tag-char-lemma', + 'vi-sdp-roberta-en': 'dm.vi.sdp.roberta', + 'vi-sdp-electra-zh': 'semeval16.vi.sdp.electra' +} +MODEL = {src: {n: f"{link}/v1.1.0/{m}.zip" for n, m in NAME.items()} for src, link in SRC.items()} +CONFIG = {src: {n: f"{link}/v1.1.0/{m}.ini" for n, m in NAME.items()} for src, link in SRC.items()} + + +def compatible(): + import sys + supar = sys.modules[__name__] + if supar.__version__ < '1.2': + sys.modules['supar.utils.transform'].CoNLL = supar.models.dep.biaffine.transform.CoNLL + sys.modules['supar.utils.transform'].Tree = supar.models.const.crf.transform.Tree + sys.modules['supar.parsers'] = supar.models + sys.modules['supar.parsers.con'] = supar.models.const + + +compatible() diff --git a/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py b/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c372e52475d87780f414396606f71a7494653b --- /dev/null +++ b/tania_scripts/supar/.ipynb_checkpoints/parser-checkpoint.py @@ -0,0 +1,620 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import contextlib +import os +import shutil +import sys +import tempfile +import pickle +from contextlib import contextmanager +from datetime import datetime, timedelta +from typing import Any, Iterable, Union + +import dill +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.cuda.amp import GradScaler +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler + +import supar +from supar.utils import Config, Dataset +from supar.utils.field import Field +from supar.utils.fn import download, get_rng_state, set_rng_state +from supar.utils.logging import get_logger, init_logger, progress_bar +from supar.utils.metric import Metric +from supar.utils.optim import InverseSquareRootLR, LinearLR +from supar.utils.parallel import DistributedDataParallel as DDP +from supar.utils.parallel import gather, is_dist, is_master, reduce +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class Parser(object): + + NAME = None + MODEL = None + + def __init__(self, args, model, transform): + self.args = args + self.model = model + self.transform = transform + + @property + def device(self): + return 'cuda' if torch.cuda.is_available() else 'cpu' + + @property + def sync_grad(self): + return self.step % self.args.update_steps == 0 or self.step % self.n_batches == 0 + + @contextmanager + def sync(self): + context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext') + if is_dist() and not self.sync_grad: + context = self.model.no_sync + with context(): + yield + + @contextmanager + def join(self): + context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext') + if not is_dist(): + with context(): + yield + elif self.model.training: + with self.model.join(): + yield + else: + try: + dist_model = self.model + # https://github.com/pytorch/pytorch/issues/54059 + if hasattr(self.model, 'module'): + self.model = self.model.module + yield + finally: + self.model = dist_model + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int, + patience: int, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + clip: float = 5.0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ) -> None: + r""" + Args: + train/dev/test (Union[str, Iterable]): + Filenames of the train/dev/test datasets. + epochs (int): + The number of training iterations. + patience (int): + The number of consecutive iterations after which the training process would be early stopped if no improvement. + batch_size (int): + The number of tokens in each batch. Default: 5000. + update_steps (int): + Gradient accumulation steps. Default: 1. + buckets (int): + The number of buckets that sentences are assigned to. Default: 32. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + clip (float): + Clips gradient of an iterable of parameters at specified value. Default: 5.0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + self.transform.train() + batch_size = batch_size // update_steps + eval_batch_size = args.get('eval_batch_size', batch_size) + if is_dist(): + batch_size = batch_size // dist.get_world_size() + eval_batch_size = eval_batch_size // dist.get_world_size() + logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + args.even = args.get('even', is_dist()) + + train = Dataset(self.transform, args.train, **args).build(batch_size=batch_size, + n_buckets=buckets, + shuffle=True, + distributed=is_dist(), + even=args.even, + n_workers=workers) + dev = Dataset(self.transform, args.dev, **args).build(batch_size=eval_batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"{'train:':6} {train}") + if not args.test: + logger.info(f"{'dev:':6} {dev}\n") + else: + test = Dataset(self.transform, args.test, **args).build(batch_size=eval_batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"{'dev:':6} {dev}") + logger.info(f"{'test:':6} {test}\n") + loader, sampler = train.loader, train.loader.batch_sampler + args.steps = len(loader) * epochs // args.update_steps + args.save(f"{args.path}.yaml") + + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + self.scaler = GradScaler(enabled=args.amp) + + if dist.is_initialized(): + self.model = DDP(module=self.model, + device_ids=[args.local_rank], + find_unused_parameters=args.get('find_unused_parameters', True), + static_graph=args.get('static_graph', False)) + if args.amp: + from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook + self.model.register_comm_hook(dist.group.WORLD, fp16_compress_hook) + if args.wandb and is_master(): + import wandb + # start a new wandb run to track this script + wandb.init(config=args.primitive_config, + project=args.get('project', self.NAME), + name=args.get('name', args.path), + resume=self.args.checkpoint) + self.step, self.epoch, self.best_e, self.patience = 1, 1, 1, patience + # uneven batches are excluded + self.n_batches = min(gather(len(loader))) if is_dist() else len(loader) + self.best_metric, self.elapsed = Metric(), timedelta() + if args.checkpoint: + try: + self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict')) + self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict')) + self.scaler.load_state_dict(self.checkpoint_state_dict.pop('scaler_state_dict')) + set_rng_state(self.checkpoint_state_dict.pop('rng_state')) + for k, v in self.checkpoint_state_dict.items(): + setattr(self, k, v) + sampler.set_epoch(self.epoch) + except AttributeError: + logger.warning("No checkpoint found. Try re-launching the training procedure instead") + + for epoch in range(self.epoch, args.epochs + 1): + start = datetime.now() + bar, metric = progress_bar(loader), Metric() + + logger.info(f"Epoch {epoch} / {args.epochs}:") + self.model.train() + with self.join(): + # we should reset `step` as the number of batches in different processes is not necessarily equal + self.step = 1 + for batch in bar: + with self.sync(): + with torch.autocast(self.device, enabled=args.amp): + loss = self.train_step(batch) + self.backward(loss) + if self.sync_grad: + self.clip_grad_norm_(self.model.parameters(), args.clip) + self.scaler.step(self.optimizer) + self.scaler.update() + self.scheduler.step() + self.optimizer.zero_grad(True) + + bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}") + # log metrics to wandb + if args.wandb and is_master(): + wandb.log({'lr': self.scheduler.get_last_lr()[0], 'loss': loss}) + self.step += 1 + logger.info(f"{bar.postfix}") + self.model.eval() + with self.join(), torch.autocast(self.device, enabled=args.amp): + metric = self.reduce(sum([self.eval_step(i) for i in progress_bar(dev.loader)], Metric())) + logger.info(f"{'dev:':5} {metric}") + if args.wandb and is_master(): + wandb.log({'dev': metric.values, 'epochs': epoch}) + if args.test: + test_metric = sum([self.eval_step(i) for i in progress_bar(test.loader)], Metric()) + logger.info(f"{'test:':5} {self.reduce(test_metric)}") + if args.wandb and is_master(): + wandb.log({'test': test_metric.values, 'epochs': epoch}) + + t = datetime.now() - start + self.epoch += 1 + self.patience -= 1 + self.elapsed += t + + if metric > self.best_metric: + self.best_e, self.patience, self.best_metric = epoch, patience, metric + if is_master(): + self.save_checkpoint(args.path) + logger.info(f"{t}s elapsed (saved)\n") + else: + logger.info(f"{t}s elapsed\n") + if self.patience < 1: + break + if is_dist(): + dist.barrier() + + best = self.load(**args) + # only allow the master device to save models + if is_master(): + best.save(args.path) + + logger.info(f"Epoch {self.best_e} saved") + logger.info(f"{'dev:':5} {self.best_metric}") + if args.test: + best.model.eval() + with best.join(): + test_metric = sum([best.eval_step(i) for i in progress_bar(test.loader)], Metric()) + logger.info(f"{'test:':5} {best.reduce(test_metric)}") + logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch") + if args.wandb and is_master(): + wandb.finish() + + self.evaluate(data=args.test, batch_size=batch_size) + self.predict(args.test, batch_size=batch_size, buckets=buckets, workers=workers) + + with open(f'{self.args.folder}/status', 'w') as file: + file.write('finished') + + + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + r""" + Args: + data (Union[str, Iterable]): + The data for evaluation. Both a filename and a list of instances are allowed. + batch_size (int): + The number of tokens in each batch. Default: 5000. + buckets (int): + The number of buckets that sentences are assigned to. Default: 8. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + + Returns: + The evaluation results. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + self.transform.train() + logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + if is_dist(): + batch_size = batch_size // dist.get_world_size() + data = Dataset(self.transform, **args) + data.build(batch_size=batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"\n{data}") + + logger.info("Evaluating the data") + start = datetime.now() + self.model.eval() + with self.join(): + bar, metric = progress_bar(data.loader), Metric() + for batch in bar: + metric += self.eval_step(batch) + bar.set_postfix_str(metric) + metric = self.reduce(metric) + elapsed = datetime.now() - start + logger.info(f"{metric}") + logger.info(f"{elapsed}s elapsed, " + f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, " + f"{len(data)/elapsed.total_seconds():.2f} Sents/s") + os.makedirs(os.path.dirname(self.args.folder + '/metrics.pickle'), exist_ok=True) + with open(f'{self.args.folder}/metrics.pickle', 'wb') as file: + pickle.dump(obj=metric, file=file) + + return metric + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + r""" + Args: + data (Union[str, Iterable]): + The data for prediction. + - a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts. + - a list of instances. + pred (str): + If specified, the predicted results will be saved to the file. Default: ``None``. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + prob (bool): + If ``True``, outputs the probabilities. Default: ``False``. + batch_size (int): + The number of tokens in each batch. Default: 5000. + buckets (int): + The number of buckets that sentences are assigned to. Default: 8. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + + Returns: + A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + if self.args.use_vq: + self.model.passes_remaining = 0 + self.model.vq.observe_steps_remaining = 0 + + self.transform.eval() + if args.prob: + self.transform.append(Field('probs')) + + #logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + if is_dist(): + batch_size = batch_size // dist.get_world_size() + data = Dataset(self.transform, **args) + data.build(batch_size=batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + + #logger.info(f"\n{data}") + + #logger.info("Making predictions on the data") + start = datetime.now() + self.model.eval() + #with tempfile.TemporaryDirectory() as t: + # we have clustered the sentences by length here to speed up prediction, + # so the order of the yielded sentences can't be guaranteed + for batch in progress_bar(data.loader): + #batch, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict = self.pred_step(batch) + *predicted_values, = self.pred_step(batch) + #print('429 supar/parser.py ', batch.sentences) + #print(head_preds, deprel_preds, stack_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text) + #logger.info(f"Saving predicted results to {pred}") + # with open(pred, 'w') as f: + + + elapsed = datetime.now() - start + + #if is_dist(): + # dist.barrier() + #tdirs = gather(t) if is_dist() else (t,) + #if pred is not None and is_master(): + #logger.info(f"Saving predicted results to {pred}") + """with open(pred, 'w') as f: + # merge all predictions into one single file + if is_dist() or args.cache: + sentences = (os.path.join(i, s) for i in tdirs for s in os.listdir(i)) + for i in progress_bar(sorted(sentences, key=lambda x: int(os.path.basename(x)))): + with open(i) as s: + shutil.copyfileobj(s, f) + else: + for s in progress_bar(data): + f.write(str(s) + '\n')""" + # exit util all files have been merged + if is_dist(): + dist.barrier() + #logger.info(f"{elapsed}s elapsed, " + # f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, " + # f"{len(data)/elapsed.total_seconds():.2f} Sents/s") + + if not cache: + #return data, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict + return *predicted_values, + + def backward(self, loss: torch.Tensor, **kwargs): + loss /= self.args.update_steps + if hasattr(self, 'scaler'): + self.scaler.scale(loss).backward(**kwargs) + else: + loss.backward(**kwargs) + + def clip_grad_norm_( + self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + max_norm: float, + norm_type: float = 2 + ) -> torch.Tensor: + self.scaler.unscale_(self.optimizer) + return nn.utils.clip_grad_norm_(params, max_norm, norm_type) + + def clip_grad_value_( + self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + clip_value: float + ) -> None: + self.scaler.unscale_(self.optimizer) + return nn.utils.clip_grad_value_(params, clip_value) + + def reduce(self, obj: Any) -> Any: + if not is_dist(): + return obj + return reduce(obj) + + def train_step(self, batch: Batch) -> torch.Tensor: + ... + + @torch.no_grad() + def eval_step(self, batch: Batch) -> Metric: + ... + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + ... + + def init_optimizer(self) -> Optimizer: + if self.args.encoder in ('lstm', 'transformer'): + optimizer = Adam(params=self.model.parameters(), + lr=self.args.lr, + betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)), + eps=self.args.get('eps', 1e-8), + weight_decay=self.args.get('weight_decay', 0)) + else: + # we found that Huggingface's AdamW is more robust and empirically better than the native implementation + from transformers import AdamW + optimizer = AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)} + for n, p in self.model.named_parameters()], + lr=self.args.lr, + betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)), + eps=self.args.get('eps', 1e-8), + weight_decay=self.args.get('weight_decay', 0)) + return optimizer + + def init_scheduler(self) -> _LRScheduler: + if self.args.encoder == 'lstm': + scheduler = ExponentialLR(optimizer=self.optimizer, + gamma=self.args.decay**(1/self.args.decay_steps)) + elif self.args.encoder == 'transformer': + scheduler = InverseSquareRootLR(optimizer=self.optimizer, + warmup_steps=self.args.warmup_steps) + else: + scheduler = LinearLR(optimizer=self.optimizer, + warmup_steps=self.args.get('warmup_steps', int(self.args.steps*self.args.get('warmup', 0))), + steps=self.args.steps) + return scheduler + + @classmethod + def build(cls, path, **kwargs): + ... + + @classmethod + def load( + cls, + path: str, + reload: bool = True, + src: str = 'github', + checkpoint: bool = False, + **kwargs + ) -> Parser: + r""" + Loads a parser with data fields and pretrained model parameters. + + Args: + path (str): + - a string with the shortcut name of a pretrained model defined in ``supar.MODEL`` + to load from cache or download, e.g., ``'biaffine-dep-en'``. + - a local path to a pretrained model, e.g., ``./<path>/model``. + reload (bool): + Whether to discard the existing cache and force a fresh download. Default: ``False``. + src (str): + Specifies where to download the model. + ``'github'``: github release page. + ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8). + Default: ``'github'``. + checkpoint (bool): + If ``True``, loads all checkpoint states to restore the training process. Default: ``False``. + + Examples: + >>> from supar import Parser + >>> parser = Parser.load('biaffine-dep-en') + >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char') + """ + + args = Config(**locals()) + if not os.path.exists(path): + path = download(supar.MODEL[src].get(path, path), reload=reload) + state = torch.load(path, map_location='cpu', weights_only=False) + #torch.load(path, map_location='cpu') + cls = supar.PARSER[state['name']] if cls.NAME is None else cls + args = state['args'].update(args) + #print('ARGS', args) + model = cls.MODEL(**args) + model.load_pretrained(state['pretrained']) + model.load_state_dict(state['state_dict'], True) + transform = state['transform'] + parser = cls(args, model, transform) + parser.checkpoint_state_dict = state.get('checkpoint_state_dict', None) if checkpoint else None + parser.model.to(parser.device) + return parser + + def save(self, path: str) -> None: + model = self.model + if hasattr(model, 'module'): + model = self.model.module + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + pretrained = state_dict.pop('pretrained.weight', None) + state = {'name': self.NAME, + 'args': model.args, + 'state_dict': state_dict, + 'pretrained': pretrained, + 'transform': self.transform} + torch.save(state, path, pickle_module=dill) + + def save_checkpoint(self, path: str) -> None: + model = self.model + if hasattr(model, 'module'): + model = self.model.module + checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']} + checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'scaler_state_dict': self.scaler.state_dict(), + 'rng_state': get_rng_state()}) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + pretrained = state_dict.pop('pretrained.weight', None) + state = {'name': self.NAME, + 'args': model.args, + 'state_dict': state_dict, + 'pretrained': pretrained, + 'checkpoint_state_dict': checkpoint_state_dict, + 'transform': self.transform} + torch.save(state, path, pickle_module=dill) diff --git a/tania_scripts/supar/__init__.py b/tania_scripts/supar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..146f3df423731c6ccad8c46a95acf8c26649da6e --- /dev/null +++ b/tania_scripts/supar/__init__.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +from .models import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, + BiaffineDependencyParser, + BiaffineSemanticDependencyParser, CRF2oDependencyParser, + CRFConstituencyParser, CRFDependencyParser, SLDependencyParser, ArcEagerDependencyParser, + TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser, + VIDependencyParser, VISemanticDependencyParser) +from .parser import Parser +from .structs import (BiLexicalizedConstituencyCRF, ConstituencyCRF, + ConstituencyLBP, ConstituencyMFVI, Dependency2oCRF, + DependencyCRF, DependencyLBP, DependencyMFVI, + LinearChainCRF, MatrixTree, SemanticDependencyLBP, + SemanticDependencyMFVI, SemiMarkovCRF) + + + +__all__ = ['Parser', + + # Dependency Parsing + 'BiaffineDependencyParser', + 'CRFDependencyParser', + 'CRF2oDependencyParser', + 'VIDependencyParser', + 'SLDependencyParser', + 'ArcEagerDependencyParser', + + # Constituent Parsing + 'AttachJuxtaposeConstituencyParser', + 'CRFConstituencyParser', + 'TetraTaggingConstituencyParser', + 'VIConstituencyParser', + 'SLConstituentParser', + + # Semantic Parsing + 'BiaffineSemanticDependencyParser', + 'VISemanticDependencyParser', + 'LinearChainCRF', + 'SemiMarkovCRF', + + # transforms + 'MatrixTree', + 'DependencyCRF', + 'Dependency2oCRF', + 'ConstituencyCRF', + 'BiLexicalizedConstituencyCRF', + 'DependencyLBP', + 'DependencyMFVI', + 'ConstituencyLBP', + 'ConstituencyMFVI', + 'SemanticDependencyLBP', + 'SemanticDependencyMFVI'] + +__version__ = '1.1.4' + +PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser, + CRFDependencyParser, + CRF2oDependencyParser, + VIDependencyParser, + SLDependencyParser, + ArcEagerDependencyParser, + AttachJuxtaposeConstituencyParser, + CRFConstituencyParser, + TetraTaggingConstituencyParser, + VIConstituencyParser, + SLConstituentParser, + BiaffineSemanticDependencyParser, + VISemanticDependencyParser]} + +SRC = {'github': 'https://github.com/yzhangcs/parser/releases/download', + 'hlt': 'http://hlt.suda.edu.cn/~yzhang/supar'} +NAME = { + 'biaffine-dep-en': 'ptb.biaffine.dep.lstm.char', + 'biaffine-dep-zh': 'ctb7.biaffine.dep.lstm.char', + 'crf2o-dep-en': 'ptb.crf2o.dep.lstm.char', + 'crf2o-dep-zh': 'ctb7.crf2o.dep.lstm.char', + 'biaffine-dep-roberta-en': 'ptb.biaffine.dep.roberta', + 'biaffine-dep-electra-zh': 'ctb7.biaffine.dep.electra', + 'biaffine-dep-xlmr': 'ud.biaffine.dep.xlmr', + 'crf-con-en': 'ptb.crf.con.lstm.char', + 'crf-con-zh': 'ctb7.crf.con.lstm.char', + 'crf-con-roberta-en': 'ptb.crf.con.roberta', + 'crf-con-electra-zh': 'ctb7.crf.con.electra', + 'crf-con-xlmr': 'spmrl.crf.con.xlmr', + 'biaffine-sdp-en': 'dm.biaffine.sdp.lstm.tag-char-lemma', + 'biaffine-sdp-zh': 'semeval16.biaffine.sdp.lstm.tag-char-lemma', + 'vi-sdp-en': 'dm.vi.sdp.lstm.tag-char-lemma', + 'vi-sdp-zh': 'semeval16.vi.sdp.lstm.tag-char-lemma', + 'vi-sdp-roberta-en': 'dm.vi.sdp.roberta', + 'vi-sdp-electra-zh': 'semeval16.vi.sdp.electra' +} +MODEL = {src: {n: f"{link}/v1.1.0/{m}.zip" for n, m in NAME.items()} for src, link in SRC.items()} +CONFIG = {src: {n: f"{link}/v1.1.0/{m}.ini" for n, m in NAME.items()} for src, link in SRC.items()} + + +def compatible(): + import sys + supar = sys.modules[__name__] + if supar.__version__ < '1.2': + sys.modules['supar.utils.transform'].CoNLL = supar.models.dep.biaffine.transform.CoNLL + sys.modules['supar.utils.transform'].Tree = supar.models.const.crf.transform.Tree + sys.modules['supar.parsers'] = supar.models + sys.modules['supar.parsers.con'] = supar.models.const + + +compatible() diff --git a/tania_scripts/supar/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5c3cf986e39a18d48492fcd3a44278864d4b01b Binary files /dev/null and b/tania_scripts/supar/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56a7800cb9807c641e2b80be758fb03c3bf53e19 Binary files /dev/null and b/tania_scripts/supar/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cfcfd5b0e97b495425c09456acbce3e34f9f610 Binary files /dev/null and b/tania_scripts/supar/__pycache__/__init__.cpython-39.pyc differ diff --git a/tania_scripts/supar/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9c7f70c814a90cca12c500a2fd2a58d1bca963 Binary files /dev/null and b/tania_scripts/supar/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda67fe0c4d0c1e433a13cbc03f01e01e490e0d9 Binary files /dev/null and b/tania_scripts/supar/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854f1c47cee59890cdb61da54ee728faf5093ffb Binary files /dev/null and b/tania_scripts/supar/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dcf19a63a0a627a255230554792697d13dc2fb5 Binary files /dev/null and b/tania_scripts/supar/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/__pycache__/vector_quantize.cpython-310.pyc b/tania_scripts/supar/__pycache__/vector_quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..821f6b3037a8f6dcf6144badeb248d0ba76a3cf0 Binary files /dev/null and b/tania_scripts/supar/__pycache__/vector_quantize.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/__init__.py b/tania_scripts/supar/cmds/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tania_scripts/supar/cmds/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/cmds/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1abebfc6e576b314b0012bd5702c707a2e1f663 Binary files /dev/null and b/tania_scripts/supar/cmds/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/__pycache__/run.cpython-310.pyc b/tania_scripts/supar/cmds/__pycache__/run.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2391df2e6f812f40aa83be07b9bc14104795013a Binary files /dev/null and b/tania_scripts/supar/cmds/__pycache__/run.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/const/aj.py b/tania_scripts/supar/cmds/const/aj.py new file mode 100644 index 0000000000000000000000000000000000000000..7952168e66d32f58ac2e1e66442699b2f7a1fe2e --- /dev/null +++ b/tania_scripts/supar/cmds/const/aj.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import AttachJuxtaposeConstituencyParser +from supar.cmds.run import init + +def main(): + parser = argparse.ArgumentParser(description='Create AttachJuxtapose Constituency Parser.') + parser.set_defaults(Parser=AttachJuxtaposeConstituencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization') + subparser.add_argument('--delay', type=int, default=0) + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/const/crf.py b/tania_scripts/supar/cmds/const/crf.py new file mode 100644 index 0000000000000000000000000000000000000000..d3809317c47506101aebe6cc29fee25d7f8e5c07 --- /dev/null +++ b/tania_scripts/supar/cmds/const/crf.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import CRFConstituencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create CRF Constituency Parser.') + parser.set_defaults(Parser=CRFConstituencyParser) + parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/const/sl.py b/tania_scripts/supar/cmds/const/sl.py new file mode 100644 index 0000000000000000000000000000000000000000..97c9d47000b9c6a7465827cd9455f0d44e2503cc --- /dev/null +++ b/tania_scripts/supar/cmds/const/sl.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import SLConstituentParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create SL Constituency Parser.') + parser.set_defaults(Parser=SLConstituentParser) + parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization') + subparser.add_argument('--max_len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization') + subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use') + subparser.add_argument('--codes', choices=['abs', 'rel'], default=None, help='sequential encoding used') + subparser.add_argument('--delay', type=int, default=0) + subparser.add_argument('--root_node', type=str, default='S') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/const/tt.py b/tania_scripts/supar/cmds/const/tt.py new file mode 100644 index 0000000000000000000000000000000000000000..286c402ec1d85ea7b6f648906438ed1a469551ee --- /dev/null +++ b/tania_scripts/supar/cmds/const/tt.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import TetraTaggingConstituencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Tetra-tagging Constituency Parser.') + parser.set_defaults(Parser=TetraTaggingConstituencyParser) + parser.add_argument('--depth', default=8, type=int, help='stack depth') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/const/vi.py b/tania_scripts/supar/cmds/const/vi.py new file mode 100644 index 0000000000000000000000000000000000000000..0b63a3b3ed207feffe3939a09f026041a47a8520 --- /dev/null +++ b/tania_scripts/supar/cmds/const/vi.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import VIConstituencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Constituency Parser using Variational Inference.') + parser.set_defaults(Parser=VIConstituencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/__pycache__/biaffine.cpython-310.pyc b/tania_scripts/supar/cmds/dep/__pycache__/biaffine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d1e880285e3ad83ee46efd54d2f67696d3632f8 Binary files /dev/null and b/tania_scripts/supar/cmds/dep/__pycache__/biaffine.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/dep/__pycache__/eager.cpython-310.pyc b/tania_scripts/supar/cmds/dep/__pycache__/eager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad57c97660122f9abeadd858036cba5f3669e9cd Binary files /dev/null and b/tania_scripts/supar/cmds/dep/__pycache__/eager.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/dep/__pycache__/sl.cpython-310.pyc b/tania_scripts/supar/cmds/dep/__pycache__/sl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce6e7c7bc3a266002f7b31faf3e1bfbd4c6f6aef Binary files /dev/null and b/tania_scripts/supar/cmds/dep/__pycache__/sl.cpython-310.pyc differ diff --git a/tania_scripts/supar/cmds/dep/biaffine.py b/tania_scripts/supar/cmds/dep/biaffine.py new file mode 100644 index 0000000000000000000000000000000000000000..315ae6e89bf7e217c8f6c6ffc8f1af9c170b3b7b --- /dev/null +++ b/tania_scripts/supar/cmds/dep/biaffine.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import BiaffineDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Biaffine Dependency Parser.') + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivize the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + parser.set_defaults(Parser=BiaffineDependencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/crf.py b/tania_scripts/supar/cmds/dep/crf.py new file mode 100644 index 0000000000000000000000000000000000000000..1229ae1fb03c7351a0f4f403b7319d2dcc840ec1 --- /dev/null +++ b/tania_scripts/supar/cmds/dep/crf.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import CRFDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create first-order CRF Dependency Parser.') + parser.set_defaults(Parser=CRFDependencyParser) + parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding') + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivize the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/crf2o.py b/tania_scripts/supar/cmds/dep/crf2o.py new file mode 100644 index 0000000000000000000000000000000000000000..cc066ec057258b2c4c48c4d51b6b9bcf90044a7d --- /dev/null +++ b/tania_scripts/supar/cmds/dep/crf2o.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import CRF2oDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create second-order CRF Dependency Parser.') + parser.set_defaults(Parser=CRF2oDependencyParser) + parser.add_argument('--mbr', action='store_true', help='whether to use MBR decoding') + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivize the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/eager.py b/tania_scripts/supar/cmds/dep/eager.py new file mode 100644 index 0000000000000000000000000000000000000000..74ae87918887bedbf673cc9dafa5b99f9de5de3b --- /dev/null +++ b/tania_scripts/supar/cmds/dep/eager.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import ArcEagerDependencyParser +from supar.cmds.run import init + +def main(): + parser = argparse.ArgumentParser(description='Create Transition Dependency Parser.') + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivize the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + parser.set_defaults(Parser=ArcEagerDependencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file') + subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization') + subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use') + subparser.add_argument('--delay', type=int, default=0) + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/sl.py b/tania_scripts/supar/cmds/dep/sl.py new file mode 100644 index 0000000000000000000000000000000000000000..97f852ba600e150a2a64b3354b1f416c86b91868 --- /dev/null +++ b/tania_scripts/supar/cmds/dep/sl.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import SLDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create SL Dependency Parsing Parser.') + parser.set_defaults(Parser=SLDependencyParser) + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivize the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--implicit', action='store_true', help='whether to conduct implicit binarization') + subparser.add_argument('--max_len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file') + subparser.add_argument('--embed', default=None, help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--use_vq', action='store_true', default=False, help='whether to use vector quantization') + subparser.add_argument('--decoder', choices=['mlp', 'lstm'], default='mlp', help='incremental decoder to use') + subparser.add_argument('--codes', choices=['abs', 'rel', 'pos', '1p', '2p'], default=None, help='SL coding used') + subparser.add_argument('--delay', type=int, default=0) + subparser.add_argument('--root_node', type=str, default='S') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.pid', help='path to dataset') + subparser.add_argument('--pred', default='pred.pid', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/dep/vi.py b/tania_scripts/supar/cmds/dep/vi.py new file mode 100644 index 0000000000000000000000000000000000000000..1175977b10b39c97cd28321e947dac09c7211c7b --- /dev/null +++ b/tania_scripts/supar/cmds/dep/vi.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import VIDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Dependency Parser using Variational Inference.') + parser.add_argument('--tree', action='store_true', help='whether to ensure well-formedness') + parser.add_argument('--proj', action='store_true', help='whether to projectivise the data') + parser.add_argument('--partial', action='store_true', help='whether partial annotation is included') + parser.set_defaults(Parser=VIDependencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file') + subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file') + subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--punct', action='store_true', help='whether to include punctuation') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/ptb/test.conllx', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllx', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/run.py b/tania_scripts/supar/cmds/run.py new file mode 100644 index 0000000000000000000000000000000000000000..df93b7227bc82c3990b8bde877c23bb09e432dcd --- /dev/null +++ b/tania_scripts/supar/cmds/run.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +import os, shutil +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from supar.utils import Config +from supar.utils.logging import init_logger, logger +from supar.utils.parallel import get_device_count, get_free_port + + +def init(parser): + parser.add_argument('--path', '-p', help='path to model file') + parser.add_argument('--conf', '-c', default='', help='path to config file') + parser.add_argument('--device', '-d', default='0', help='ID of GPU to use') + parser.add_argument('--seed', '-s', default=1, type=int, help='seed for generating random numbers') + parser.add_argument('--threads', '-t', default=16, type=int, help='num of threads') + parser.add_argument('--workers', '-w', default=0, type=int, help='num of processes used for data loading') + parser.add_argument('--cache', action='store_true', help='cache the data for fast loading') + parser.add_argument('--binarize', action='store_true', help='binarize the data first') + parser.add_argument('--amp', action='store_true', help='use automatic mixed precision for parsing') + parser.add_argument('--dist', choices=['ddp', 'fsdp'], default='ddp', help='distributed training types') + parser.add_argument('--wandb', action='store_true', help='wandb for tracking experiments') + args, unknown = parser.parse_known_args() + args, unknown = parser.parse_known_args(unknown, args) + args = Config.load(**vars(args), unknown=unknown) + + args.folder = '/'.join(args.path.split('/')[:-1]) + if not os.path.exists(args.folder): + os.makedirs(args.folder) + + os.environ['CUDA_VISIBLE_DEVICES'] = args.device + if get_device_count() > 1: + os.environ['MASTER_ADDR'] = 'tcp://localhost' + os.environ['MASTER_PORT'] = get_free_port() + mp.spawn(parse, args=(args,), nprocs=get_device_count()) + else: + parse(0 if torch.cuda.is_available() else -1, args) + + +def parse(local_rank, args): + Parser = args.pop('Parser') + torch.set_num_threads(args.threads) + torch.manual_seed(args.seed) + if get_device_count() > 1: + dist.init_process_group(backend='nccl', + init_method=f"{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}", + world_size=get_device_count(), + rank=local_rank) + torch.cuda.set_device(local_rank) + # init logger after dist has been initialized + init_logger(logger, f"{args.path}.{args.mode}.log", 'a' if args.get('checkpoint') else 'w') + logger.info('\n' + str(args)) + + args.local_rank = local_rank + os.environ['RANK'] = os.environ['LOCAL_RANK'] = f'{local_rank}' + if args.mode == 'train': + parser = Parser.load(**args) if args.checkpoint else Parser.build(**args) + parser.train(**args) + elif args.mode == 'evaluate': + parser = Parser.load(**args) + parser.evaluate(**args) + elif args.mode == 'predict': + parser = Parser.load(**args) + parser.predict(**args) diff --git a/tania_scripts/supar/cmds/sdp/biaffine.py b/tania_scripts/supar/cmds/sdp/biaffine.py new file mode 100644 index 0000000000000000000000000000000000000000..a36ab6a40c4ccba8abe47ea32732e3efc468c398 --- /dev/null +++ b/tania_scripts/supar/cmds/sdp/biaffine.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import BiaffineSemanticDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Biaffine Semantic Dependency Parser.') + parser.set_defaults(Parser=BiaffineSemanticDependencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file') + subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file') + subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/cmds/sdp/vi.py b/tania_scripts/supar/cmds/sdp/vi.py new file mode 100644 index 0000000000000000000000000000000000000000..26fee77c40d1475e93f12c187ed66f8f3810aac1 --- /dev/null +++ b/tania_scripts/supar/cmds/sdp/vi.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +import argparse + +from supar import VISemanticDependencyParser +from supar.cmds.run import init + + +def main(): + parser = argparse.ArgumentParser(description='Create Semantic Dependency Parser using Variational Inference.') + parser.set_defaults(Parser=VISemanticDependencyParser) + subparsers = parser.add_subparsers(title='Commands', dest='mode') + # train + subparser = subparsers.add_parser('train', help='Train a parser.') + subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'elmo', 'bert'], nargs='*', help='features to use') + subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') + subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') + subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use') + subparser.add_argument('--max-len', type=int, help='max length of the sentences') + subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') + subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file') + subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file') + subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file') + subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') + subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings') + subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use') + subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods') + # evaluate + subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset') + # predict + subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') + subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') + subparser.add_argument('--data', default='data/sdp/DM/test.conllu', help='path to dataset') + subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result') + subparser.add_argument('--prob', action='store_true', help='whether to output probs') + init(parser) + + +if __name__ == "__main__": + main() diff --git a/tania_scripts/supar/codelin/__init__.py b/tania_scripts/supar/codelin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..847de4fc4f73b6f0f337f356df357cbfbe570bf1 --- /dev/null +++ b/tania_scripts/supar/codelin/__init__.py @@ -0,0 +1,34 @@ +from .encs.enc_deps import D_NaiveAbsoluteEncoding, D_NaiveRelativeEncoding, D_PosBasedEncoding, D_BrkBasedEncoding, D_Brk2PBasedEncoding +from .encs.enc_const import C_NaiveAbsoluteEncoding, C_NaiveRelativeEncoding +from .utils.constants import D_2P_GREED, D_2P_PROP + +# import structures for encoding/decoding +from .models.const_label import C_Label +from .models.const_tree import C_Tree +from .models.linearized_tree import LinearizedTree +from .models.deps_label import D_Label +from .models.deps_tree import D_Tree + +LABEL_SEPARATOR = '€' +UNARY_JOINER = '@' + +def get_con_encoder(encoding: str, sep: str = LABEL_SEPARATOR, unary_joiner: str = UNARY_JOINER): + if encoding == 'abs': + return C_NaiveAbsoluteEncoding(sep, unary_joiner) + elif encoding == 'rel': + return C_NaiveRelativeEncoding(sep, unary_joiner) + return NotImplementedError + +def get_dep_encoder(encoding: str, sep: str, displacement: bool = False): + if encoding == 'abs': + return D_NaiveAbsoluteEncoding(sep) + elif encoding == 'rel': + return D_NaiveRelativeEncoding(sep, hang_from_root=False) + elif encoding == 'pos': + return D_PosBasedEncoding(sep) + elif encoding == '1p': + return D_BrkBasedEncoding(sep, displacement) + elif encoding == '2p': + return D_Brk2PBasedEncoding(sep, displacement, D_2P_PROP) + return NotImplementedError + diff --git a/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b03f9cdf5de941345a04314d4e726119e9dfd2 Binary files /dev/null and b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..737b648aa6c9eb05010e8759390d493979dff7fc Binary files /dev/null and b/tania_scripts/supar/codelin/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..163660a55c0e295e147a3eb4645efc78e2a5e088 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7faea83a2faeba0be37ddc0648b2989d60fcc074 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/__pycache__/abstract_encoding.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/abstract_encoding.py b/tania_scripts/supar/codelin/encs/abstract_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3fa035e7f5883068f9e641d164e366cc0c37a4 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/abstract_encoding.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod + +class ADEncoding(ABC): + ''' + Abstract class for Dependency Encodings + Sets the main constructor method and defines the methods + - Encode + - Decode + When adding a new Dependency Encoding it must extend this class + and implement those methods + ''' + def __init__(self, separator): + self.separator = separator + + def encode(self, nodes): + pass + def decode(self, labels, postags, words): + pass + +class ACEncoding(ABC): + ''' + Abstract class for Constituent Encodings + Sets the main constructor method and defines the abstract methods + - Encode + - Decode + When adding a new Constituent Encoding it must extend this class + and implement those methods + ''' + def __init__(self, separator, ujoiner): + self.separator = separator + self.unary_joiner = ujoiner + + def encode(self, constituent_tree): + pass + def decode(self, linearized_tree): + pass + diff --git a/tania_scripts/supar/codelin/encs/constituent.py b/tania_scripts/supar/codelin/encs/constituent.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbe8f0b18270ebf769078c825e3441e22acd766 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/constituent.py @@ -0,0 +1,115 @@ +from src.models.linearized_tree import LinearizedTree +from src.encs.enc_const import * +from src.utils.extract_feats import extract_features_const +from src.utils.constants import C_INCREMENTAL_ENCODING, C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_DYNAMIC_ENCODING + +import stanza.pipeline +from supar.codelin.models.const_tree import C_Tree + + +## Encoding and decoding + +def encode_constituent(in_path, out_path, encoding_type, separator, unary_joiner, features): + ''' + Encodes the selected file according to the specified parameters: + :param in_path: Path of the file to be encoded + :param out_path: Path where to write the encoded labels + :param encoding_type: Encoding used + :param separator: string used to separate label fields + :param unary_joiner: string used to separate nodes from unary chains + :param features: features to add as columns to the labels file + ''' + + if encoding_type == C_ABSOLUTE_ENCODING: + encoder = C_NaiveAbsoluteEncoding(separator, unary_joiner) + elif encoding_type == C_RELATIVE_ENCODING: + encoder = C_NaiveRelativeEncoding(separator, unary_joiner) + elif encoding_type == C_DYNAMIC_ENCODING: + encoder = C_NaiveDynamicEncoding(separator, unary_joiner) + elif encoding_type == C_INCREMENTAL_ENCODING: + encoder = C_NaiveIncrementalEncoding(separator, unary_joiner) + else: + raise Exception("Unknown encoding type") + + # build feature index dictionary + f_idx_dict = {} + if features: + if features == ["ALL"]: + features = extract_features_const(in_path) + i=0 + for f in features: + f_idx_dict[f]=i + i+=1 + + file_out = open(out_path, "w") + file_in = open(in_path, "r") + + tree_counter = 0 + labels_counter = 0 + label_set = set() + + for line in file_in: + line = line.rstrip() + tree = C_Tree.from_string(line) + linearized_tree = encoder.encode(tree) + file_out.write(linearized_tree.to_string(f_idx_dict)) + file_out.write("\n") + tree_counter += 1 + labels_counter += len(linearized_tree) + for lbl in linearized_tree.get_labels(): + label_set.add(str(lbl)) + + return labels_counter, tree_counter, len(label_set) + +def decode_constituent(in_path, out_path, encoding_type, separator, unary_joiner, conflicts, nulls, postags, lang): + ''' + Decodes the selected file according to the specified parameters: + :param in_path: Path of the labels file to be decoded + :param out_path: Path where to write the decoded tree + :param encoding_type: Encoding used + :param separator: string used to separate label fields + :param unary_joiner: string used to separate nodes from unary chains + :param conflicts: conflict resolution heuristics to apply + ''' + + if encoding_type == C_ABSOLUTE_ENCODING: + decoder = C_NaiveAbsoluteEncoding(separator, unary_joiner) + elif encoding_type == C_RELATIVE_ENCODING: + decoder = C_NaiveRelativeEncoding(separator, unary_joiner) + elif encoding_type == C_DYNAMIC_ENCODING: + decoder = C_NaiveDynamicEncoding(separator, unary_joiner) + elif encoding_type == C_INCREMENTAL_ENCODING: + decoder = C_NaiveIncrementalEncoding(separator, unary_joiner) + else: + raise Exception("Unknown encoding type") + + if postags: + stanza.download(lang=lang) + nlp = stanza.Pipeline(lang=lang, processors='tokenize, pos') + + f_in = open(in_path) + f_out = open(out_path,"w+") + + tree_string = "" + labels_counter = 0 + tree_counter = 0 + + for line in f_in: + if line == "\n": + tree_string = tree_string.rstrip() + current_tree = LinearizedTree.from_string(tree_string, mode="CONST", separator=separator, unary_joiner=unary_joiner) + + if postags: + c_tags = nlp(current_tree.get_sentence()) + current_tree.set_postags([word.pos for word in c_tags]) + + decoded_tree = decoder.decode(current_tree) + decoded_tree = decoded_tree.postprocess_tree(conflicts, nulls) + + f_out.write(str(decoded_tree).replace('\n','')+'\n') + tree_string = "" + tree_counter+=1 + tree_string += line + labels_counter += 1 + + return tree_counter, labels_counter \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/dependency.py b/tania_scripts/supar/codelin/encs/dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a49d41dbc2f038812142ee05741099e8460a28 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/dependency.py @@ -0,0 +1,129 @@ +import stanza +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.deps_label import D_Label +from supar.codelin.utils.extract_feats import extract_features_deps +from supar.codelin.encs.enc_deps import * +from supar.codelin.utils.constants import * +from supar.codelin.models.deps_tree import D_Tree + +# Encoding +def encode_dependencies(in_path, out_path, encoding_type, separator, displacement, planar_alg, root_enc, features): + ''' + Encodes the selected file according to the specified parameters: + :param in_path: Path of the file to be encoded + :param out_path: Path where to write the encoded labels + :param encoding_type: Encoding used + :param separator: string used to separate label fields + :param displacement: boolean to indicate if use displacement in bracket based encodings + :param planar_alg: string used to choose the plane separation algorithm + :param features: features to add as columns to the labels file + ''' + + # Create the encoder + if encoding_type == D_ABSOLUTE_ENCODING: + encoder = D_NaiveAbsoluteEncoding(separator) + elif encoding_type == D_RELATIVE_ENCODING: + encoder = D_NaiveRelativeEncoding(separator, root_enc) + elif encoding_type == D_POS_ENCODING: + encoder = D_PosBasedEncoding(separator) + elif encoding_type == D_BRACKET_ENCODING: + encoder = D_BrkBasedEncoding(separator, displacement) + elif encoding_type == D_BRACKET_ENCODING_2P: + encoder = D_Brk2PBasedEncoding(separator, displacement, planar_alg) + else: + raise Exception("Unknown encoding type") + + f_idx_dict = {} + if features: + if features == ["ALL"]: + features = extract_features_deps(in_path) + i=0 + for f in features: + f_idx_dict[f]=i + i+=1 + + file_out = open(out_path,"w+") + label_set = set() + tree_counter = 0 + label_counter = 0 + trees = D_Tree.read_conllu_file(in_path, filter_projective=False) + + for t in trees: + # encode labels + linearized_tree = encoder.encode(t) + file_out.write(linearized_tree.to_string(f_idx_dict)) + file_out.write("\n") + + tree_counter+=1 + label_counter+=len(linearized_tree) + + for lbl in linearized_tree.get_labels(): + label_set.add(str(lbl)) + + print('/supar/codelin/encs/dependency.py') + return tree_counter, label_counter, len(label_set) + +# Decoding + +def decode_dependencies(in_path, out_path, encoding_type, separator, displacement, multiroot, root_search, root_enc, postags, lang): + ''' + Decodes the selected file according to the specified parameters: + :param in_path: Path of the file to be encoded + :param out_path: Path where to write the encoded labels + :param encoding_type: Encoding used + :param separator: string used to separate label fields + :param displacement: boolean to indicate if use displacement in bracket based encodings + :param multiroot: boolean to indicate if multiroot conll trees are allowed + :param root_search: strategy to select how to search the root if no root found in decoded tree + ''' + + if encoding_type == D_ABSOLUTE_ENCODING: + decoder = D_NaiveAbsoluteEncoding(separator) + elif encoding_type == D_RELATIVE_ENCODING: + decoder = D_NaiveRelativeEncoding(separator, root_enc) + elif encoding_type == D_POS_ENCODING: + decoder = D_PosBasedEncoding(separator) + elif encoding_type == D_BRACKET_ENCODING: + decoder = D_BrkBasedEncoding(separator, displacement) + elif encoding_type == D_BRACKET_ENCODING_2P: + decoder = D_Brk2PBasedEncoding(separator, displacement, None) + else: + raise Exception("Unknown encoding type") + + f_in=open(in_path) + f_out=open(out_path,"w+") + + tree_counter=0 + labels_counter=0 + + tree_string = "" + + print('/supar/codelin/encs/dependency.py 101') + + if postags: + stanza.download(lang=lang) + nlp = stanza.Pipeline(lang=lang, processors='tokenize,pos') + + + for line in f_in: + if line == "\n": + tree_string = tree_string.rstrip() + current_tree = LinearizedTree.from_string(tree_string, mode="DEPS", separator=separator) + + if postags: + c_tags = nlp(current_tree.get_sentence()) + current_tree.set_postags([word.pos for word in c_tags]) + + decoded_tree = decoder.decode(current_tree) + decoded_tree.postprocess_tree(root_search, multiroot) + f_out.write("# text = "+decoded_tree.get_sentence()+"\n") + f_out.write(str(decoded_tree)) + + tree_string = "" + tree_counter+=1 + + tree_string += line + labels_counter += 1 + + + return tree_counter, labels_counter diff --git a/tania_scripts/supar/codelin/encs/enc_const/__init__.py b/tania_scripts/supar/codelin/encs/enc_const/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d803a3b0a700fe22bdc565435e46c84b885edf38 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_const/__init__.py @@ -0,0 +1,4 @@ +from .naive_absolute import C_NaiveAbsoluteEncoding +from .naive_relative import C_NaiveRelativeEncoding +from .naive_dynamic import C_NaiveDynamicEncoding +from .naive_incremental import C_NaiveIncrementalEncoding \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41d16cbfa21181d065d0c1f0e255961bc5fc54b8 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12f0189eb2fabe6bb71d622e2452f312b08a7c8c Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e654e52ff2f574d7c86285d440042167471c068d Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb3ffaca5b999a39848ef14e042cb37fd4f4ebae Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_absolute.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d6495f2c7305a9257f7a536493cc1034c193804 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ad5b6badac7cd3836fb4078ae00b30e8c29bd2 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_dynamic.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df5670f680c988a0e1be47f68f977f4e1915c30 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f6374cacf05ac756434ee357b9b18a208fcc681 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_incremental.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58ce1fee263b5b20fbc8fe069e0ca2f9155706e4 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42a4c47d2c9a3e7745917396a876bf0af990d50 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_const/__pycache__/naive_relative.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py b/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbb9844109e33406fb7056bdfe3e4e6e2e07c17 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_const/naive_absolute.py @@ -0,0 +1,144 @@ +from supar.codelin.encs.abstract_encoding import ACEncoding +from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL +from supar.codelin.models.const_label import C_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.const_tree import C_Tree + +import re + +class C_NaiveAbsoluteEncoding(ACEncoding): + def __init__(self, separator, unary_joiner): + self.separator = separator + self.unary_joiner = unary_joiner + + def __str__(self): + return "Constituent Naive Absolute Encoding" + + def encode(self, constituent_tree): + lc_tree = LinearizedTree.empty_tree() + leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner) + + for i in range(0, len(leaf_paths)-1): + path_a = leaf_paths[i] + path_b = leaf_paths[i+1] + + last_common = "" + n_commons = 0 + for a,b in zip(path_a, path_b): + + if (a!=b): + # Remove the digits and aditional feats in the last common node + last_common = re.sub(r'[0-9]+', '', last_common) + last_common = last_common.split("##")[0] + + # Get word and POS tag + word = path_a[-1] + postag = path_a[-2] + + # Build the Leaf Unary Chain + unary_chain = None + leaf_unary_chain = postag.split(self.unary_joiner) + if len(leaf_unary_chain)>1: + unary_list = [] + for element in leaf_unary_chain[:-1]: + unary_list.append(element.split("##")[0]) + + unary_chain = self.unary_joiner.join(unary_list) + postag = leaf_unary_chain[len(leaf_unary_chain)-1] + + # Clean the POS Tag and extract additional features + postag_split = postag.split("##") + feats = [None] + + if len(postag_split) > 1: + postag = re.sub(r'[0-9]+', '', postag_split[0]) + feats = postag_split[1].split("|") + else: + postag = re.sub(r'[0-9]+', '', postag) + + c_label= C_Label(n_commons, last_common, unary_chain, C_ABSOLUTE_ENCODING, + self.separator, self.unary_joiner) + + # Append the data + lc_tree.add_row(word, postag, feats, c_label) + + break + + n_commons += len(a.split(self.unary_joiner)) + last_common = a + + # n = max number of features of the tree + lc_tree.n = max([len(f) for f in lc_tree.additional_feats]) + return lc_tree + + def decode(self, linearized_tree): + # Check valid labels + if not linearized_tree: + print("[!] Error while decoding: Null tree.") + return + + # Create constituent tree + tree = C_Tree(C_ROOT_LABEL, []) + current_level = tree + + old_n_commons = 0 + old_level = None + for word, postag, feats, label in linearized_tree.iterrows(): + + # Descend through the tree until reach the level indicated by last_common + current_level = tree + for level_index in range(label.n_commons): + if (current_level.is_terminal()) or (level_index >= old_n_commons): + current_level.add_child(C_Tree(C_NONE_LABEL, [])) + + current_level = current_level.r_child() + + # Split the Last Common field of the Label in case it has a Unary Chain Collapsed + label.last_common = label.last_common.split(self.unary_joiner) + + if len(label.last_common) == 1: + # If current level has no label yet, put the label + # If current level has label but different than this one, set it as a conflict + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[0].rstrip() + else: + current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0] + if len(label.last_common)>1: + current_level = tree + + # problem when n_commons predicted is LESS than the number of last commons predicted + descend_levels = label.n_commons - (len(label.last_common)) + 1 + for level_index in range(descend_levels): + current_level = current_level.r_child() + for i in range(len(label.last_common)-1): + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[i] + else: + current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i] + + if len(current_level.children)>0: + current_level = current_level.r_child() + + # If we reach a POS tag, set it as child of the current chain + if current_level.is_preterminal(): + temp_current_level_children = current_level.children + current_level.label = label.last_common[i+1] + current_level.children = temp_current_level_children + for c in temp_current_level_children: + c.parent = current_level + else: + current_level.label=label.last_common[i+1] + + + # Fill POS tag in this node or previous one + if (label.n_commons >= old_n_commons): + current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + else: + old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + + old_n_commons=label.n_commons + old_level=current_level + + tree.inherit_tree() + + return tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py b/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..c4dc11f5ce16a283057c0852ca4062d21e92e8b6 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_const/naive_dynamic.py @@ -0,0 +1,166 @@ +from supar.codelin.encs.abstract_encoding import ACEncoding +from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL +from supar.codelin.models.const_label import C_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.const_tree import C_Tree + +import re + +class C_NaiveDynamicEncoding(ACEncoding): + def __init__(self, separator, unary_joiner): + self.separator = separator + self.unary_joiner = unary_joiner + + def __str__(self): + return "Constituent Naive Dynamic Encoding" + + def encode(self, constituent_tree): + lc_tree = LinearizedTree.empty_tree() + leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner) + + last_n_common=0 + for i in range(0, len(leaf_paths)-1): + path_a=leaf_paths[i] + path_b=leaf_paths[i+1] + + last_common="" + n_commons=0 + for a,b in zip(path_a, path_b): + + if (a!=b): + # Remove the digits and aditional feats in the last common node + last_common = re.sub(r'[0-9]+', '', last_common) + last_common = last_common.split("##")[0] + + # Get word and POS tag + word = path_a[-1] + postag = path_a[-2] + + # Build the Leaf Unary Chain + unary_chain = None + leaf_unary_chain = postag.split(self.unary_joiner) + if len(leaf_unary_chain)>1: + unary_list = [] + for element in leaf_unary_chain[:-1]: + unary_list.append(element.split("##")[0]) + + unary_chain = self.unary_joiner.join(unary_list) + postag = leaf_unary_chain[len(leaf_unary_chain)-1] + + # Clean the POS Tag and extract additional features + postag_split = postag.split("##") + feats = [None] + + if len(postag_split) > 1: + postag = re.sub(r'[0-9]+', '', postag_split[0]) + feats = postag_split[1].split("|") + else: + postag = re.sub(r'[0-9]+', '', postag) + + # Compute the encoded value + abs_val=n_commons + rel_val=(n_commons-last_n_common) + + if (abs_val<=3 and rel_val<=-2): + c_label = (C_Label(abs_val, last_common, unary_chain, C_ABSOLUTE_ENCODING, self.separator, self.unary_joiner)) + else: + c_label = (C_Label(rel_val, last_common, unary_chain, C_RELATIVE_ENCODING, self.separator, self.unary_joiner)) + + lc_tree.add_row(word, postag, feats, c_label) + + last_n_common=n_commons + break + + # Store Last Common and increase n_commons + # Note: When increasing n_commons use the number from split the collapsed chains + n_commons += len(a.split(self.unary_joiner)) + last_common = a + + # n = max number of features of the tree + lc_tree.n = max([len(f) for f in lc_tree.additional_feats]) + return lc_tree + + def decode(self, linearized_tree): + # Check valid labels + if not linearized_tree: + print("[*] Error while decoding: Null tree.") + return + + # Create constituent tree + tree = C_Tree(C_ROOT_LABEL, []) + current_level = tree + + old_n_commons=0 + old_last_common='' + old_level=None + is_first = True + last_label = None + + for word, postag, feats, label in linearized_tree.iterrows(): + + # Convert the labels to absolute scale + if last_label!=None and label.encoding_type==C_RELATIVE_ENCODING: + label.to_absolute(last_label) + + # First label must have a positive n_commons value + if is_first and label.n_commons <= 0: + label.n_commons = 1 + + # Descend through the tree until reach the level indicated by last_common + current_level = tree + for level_index in range(label.n_commons): + if (len(current_level.children)==0) or (level_index >= old_n_commons): + current_level.add_child(C_Tree(C_NONE_LABEL, [])) + + current_level = current_level.r_child() + + # Split the Last Common field of the Label in case it has a Unary Chain Collapsed + label.last_common = label.last_common.split(self.unary_joiner) + + if len(label.last_common)==1: + # If current level has no label yet, put the label + # If current level has label but different than this one, set it as a conflict + if (current_level.label==C_NONE_LABEL): + current_level.label = label.last_common[0].rstrip() + else: + current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0] + if len(label.last_common)>1: + current_level = tree + + # problem when n_commons predicted is LESS than the number of last commons predicted + descend_levels = label.n_commons - (len(label.last_common)) + 1 + for level_index in range(descend_levels): + current_level = current_level.r_child() + for i in range(len(label.last_common)-1): + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[i] + else: + current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i] + + if len(current_level.children)>0: + current_level = current_level.r_child() + + # If we reach a POS tag, set it as child of the current chain + if current_level.is_preterminal(): + temp_current_level_children = current_level.children + current_level.label = label.last_common[i+1] + current_level.children = temp_current_level_children + for c in temp_current_level_children: + c.parent = current_level + else: + current_level.label=label.last_common[i+1] + + + # Fill POS tag in this node or previous one + if (label.n_commons >= old_n_commons): + current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + else: + old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + + old_n_commons=label.n_commons + old_level=current_level + last_label=label + + tree.inherit_tree() + + return tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py b/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py new file mode 100644 index 0000000000000000000000000000000000000000..8ceb06b39916535b0e15110abf66fe7d5997387f --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_const/naive_incremental.py @@ -0,0 +1,160 @@ +from supar.codelin.encs.abstract_encoding import ACEncoding +from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL, C_DUMMY_END +from supar.codelin.models.const_label import C_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.const_tree import C_Tree + +import re + +class C_NaiveIncrementalEncoding(ACEncoding): + def __init__(self, separator, unary_joiner): + self.separator = separator + self.unary_joiner = unary_joiner + + def __str__(self): + return "Constituent Naive Incremental Encoding" + + def get_unary_chain(self, postag): + unary_chain = None + leaf_unary_chain = postag.split(self.unary_joiner) + + if len(leaf_unary_chain)>1: + unary_list = [] + for element in leaf_unary_chain[:-1]: + unary_list.append(element.split("##")[0]) + + unary_chain = self.unary_joiner.join(unary_list) + postag = leaf_unary_chain[len(leaf_unary_chain)-1] + + return unary_chain, postag + + def get_features(self, node, feature_marker="##", feature_splitter="|"): + postag_split = node.split(feature_marker) + feats = None + + if len(postag_split) > 1: + postag = re.sub(r'[0-9]+', '', postag_split[0]) + feats = postag_split[1].split(feature_splitter) + else: + postag = re.sub(r'[0-9]+', '', node) + return postag, feats + + def clean_last_common(self, node, feature_marker="##"): + node = re.sub(r'[0-9]+', '', node) + last_common = node.split(feature_marker)[0] + return last_common + + def encode(self, constituent_tree): + constituent_tree.reverse_tree() + leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner) + lc_tree = LinearizedTree.empty_tree() + + for i in range(1, len(leaf_paths)): + path_a = leaf_paths[i-1] + path_b = leaf_paths[i] + + last_common = "" + n_commons = 0 + + for a,b in zip(path_a, path_b): + if (a!=b): + # Remove the digits and aditional feats in the last common node + last_common = self.clean_last_common(last_common) + + # Get word and POS tag + word = path_a[-1] + postag = path_a[-2] + + # Build the Leaf Unary Chain + unary_chain, postag = self.get_unary_chain(postag) + + # Clean the POS Tag and extract additional features + postag, feats = self.get_features(postag) + + # Append the data + c_label = (C_Label(n_commons, last_common, unary_chain, C_ABSOLUTE_ENCODING, self.separator, self.unary_joiner)) + lc_tree.add_row(word, postag, feats, c_label) + + break + + # Store Last Common and increase n_commons + # Note: When increasing n_commons use the number from split the collapsed chains + n_commons += len(a.split(self.unary_joiner)) + last_common = a + + # reverse and return + lc_tree.reverse_tree(ignore_bos_eos=False) + return lc_tree + + def decode(self, linearized_tree): + # Check valid labels + if not linearized_tree: + print("[*] Error while decoding: Null tree.") + return + + # Create constituent tree + tree = C_Tree(C_ROOT_LABEL, []) + current_level = tree + + old_n_commons=0 + old_level=None + + linearized_tree.reverse_tree(ignore_bos_eos=False) + for word, postag, feats, label in linearized_tree.iterrows(): + + # Descend through the tree until reach the level indicated by last_common + current_level = tree + for level_index in range(label.n_commons): + if (current_level.is_terminal()) or (level_index >= old_n_commons): + current_level.add_child(C_Tree(C_NONE_LABEL, [])) + + current_level = current_level.r_child() + + # Split the Last Common field of the Label in case it has a Unary Chain Collapsed + label.last_common = label.last_common.split(self.unary_joiner) + + if len(label.last_common) == 1: + # If current level has no label yet, put the label + # If current level has label but different than this one, set it as a conflict + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[0] + else: + current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0] + else: + current_level = tree + + # Descend to the beginning of the Unary Chain and fill it + descend_levels = label.n_commons - (len(label.last_common)) + 1 + + for level_index in range(descend_levels): + current_level = current_level.r_child() + + for i in range(len(label.last_common)-1): + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[i] + else: + current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[i] + current_level = current_level.r_child() + + # If we reach a POS tag, set it as child of the current chain + if current_level.is_preterminal(): + temp_current_level = current_level + current_level.label = label.last_common[i+1] + current_level.children = [temp_current_level] + + else: + current_level.label=label.last_common[i+1] + + # Fill POS tag in this node or previous one + if (label.n_commons >= old_n_commons): + current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + + else: + old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + + old_n_commons=label.n_commons + old_level=current_level + + tree.inherit_tree() + tree.reverse_tree() + return tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py b/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py new file mode 100644 index 0000000000000000000000000000000000000000..0794e3be915243487d985a33af95a3be427552e4 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_const/naive_relative.py @@ -0,0 +1,152 @@ +from supar.codelin.encs.abstract_encoding import ACEncoding +from supar.codelin.utils.constants import C_RELATIVE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL +from supar.codelin.models.const_label import C_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.const_tree import C_Tree + +import re + +class C_NaiveRelativeEncoding(ACEncoding): + def __init__(self, separator, unary_joiner): + self.separator = separator + self.unary_joiner = unary_joiner + + def __str__(self): + return "Constituent Naive Relative Encoding" + + def encode(self, constituent_tree): + leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner) + lc_tree = LinearizedTree.empty_tree() + + last_n_common=0 + for i in range(0, len(leaf_paths)-1): + path_a=leaf_paths[i] + path_b=leaf_paths[i+1] + + last_common="" + n_commons=0 + for a,b in zip(path_a, path_b): + + if (a!=b): + # Remove the digits and aditional feats in the last common node + last_common = re.sub(r'[0-9]+', '', last_common) + last_common = last_common.split("##")[0] + + # Get word and POS tag + word = path_a[-1] + postag = path_a[-2] + + # Build the Leaf Unary Chain + unary_chain = None + leaf_unary_chain = postag.split(self.unary_joiner) + if len(leaf_unary_chain)>1: + unary_list = [] + for element in leaf_unary_chain[:-1]: + unary_list.append(element.split("##")[0]) + + unary_chain = self.unary_joiner.join(unary_list) + postag = leaf_unary_chain[len(leaf_unary_chain)-1] + + # Clean the POS Tag and extract additional features + postag_split = postag.split("##") + feats = None + + if len(postag_split) > 1: + postag = re.sub(r'[0-9]+', '', postag_split[0]) + feats = postag_split[1].split("|") + else: + postag = re.sub(r'[0-9]+', '', postag) + + c_label = C_Label((n_commons-last_n_common), last_common, unary_chain, C_RELATIVE_ENCODING, self.separator, self.unary_joiner) + lc_tree.add_row(word, postag, feats, c_label) + + last_n_common=n_commons + break + + # Store Last Common and increase n_commons + # Note: When increasing n_commons use the number from split the collapsed chains + n_commons += len(a.split(self.unary_joiner)) + last_common = a + + return lc_tree + + def decode(self, linearized_tree): + # Check valid labels + if not linearized_tree: + print("[*] Error while decoding: Null tree.") + return + + # Create constituent tree + tree = C_Tree(C_ROOT_LABEL, []) + current_level = tree + + old_n_commons=0 + old_level=None + + is_first = True + last_label = None + + for word, postag, feats, label in linearized_tree.iterrows(): + # Convert the labels to absolute scale + if last_label!=None: + label.to_absolute(last_label) + + # First label must have a positive n_commons value + if is_first and label.n_commons <= 0: + label.n_commons = 1 + + # Descend through the tree until reach the level indicated by last_common + current_level = tree + for level_index in range(label.n_commons): + if (current_level.is_terminal()) or (level_index >= old_n_commons): + current_level.add_child(C_Tree(C_NONE_LABEL, [])) + current_level = current_level.r_child() + + # Split the Last Common field of the Label in case it has a Unary Chain Collapsed + label.last_common = label.last_common.split(self.unary_joiner) + + if len(label.last_common)==1: + if (current_level.label==C_NONE_LABEL): + current_level.label=label.last_common[0].rstrip() + else: + current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0] + + if len(label.last_common)>1: + current_level = tree + + # problem when n_commons predicted is LESS than the number of last commons predicted + descend_levels = label.n_commons - (len(label.last_common)) + 1 + for level_index in range(descend_levels): + current_level = current_level.r_child() + for i in range(len(label.last_common)-1): + if (current_level.label == C_NONE_LABEL): + current_level.label = label.last_common[i] + else: + current_level.label = current_level.label+C_CONFLICT_SEPARATOR+label.last_common[i] + + if len(current_level.children)>0: + current_level = current_level.r_child() + + # If we reach a POS tag, set it as child of the current chain + if current_level.is_preterminal(): + temp_current_level_children = current_level.children + current_level.label = label.last_common[i+1] + current_level.children = temp_current_level_children + for c in temp_current_level_children: + c.parent = current_level + else: + current_level.label=label.last_common[i+1] + + # Fill POS tag in this node or previous one + if (label.n_commons >= old_n_commons): + current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + else: + old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner) + + old_n_commons=label.n_commons + old_level=current_level + + last_label=label + + tree.inherit_tree() + return tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__init__.py b/tania_scripts/supar/codelin/encs/enc_deps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd652c52beb69640767bbcf8201eaa1800818c0 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/__init__.py @@ -0,0 +1,6 @@ +## Dependency Encodings Package +from .naive_absolute import D_NaiveAbsoluteEncoding +from .naive_relative import D_NaiveRelativeEncoding +from .brk_based import D_BrkBasedEncoding +from .pos_based import D_PosBasedEncoding +from .brk2p_based import D_Brk2PBasedEncoding \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d526f8782acc6ddb352bc54ce8ef508f64fe80 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ce46dfd865d40f78b478a5a314786080bc9266c Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50aeb39bab298e183c8d9afec05cac65fb3f9a5e Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbbe00aafec12095b89292d53a158512e1969458 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk2p_based.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08db178e6917bb38a99b0d745ffe3f08121eeeec Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c2dea9b223e0a9c2a7ac3e40c647b6abc9ca0a Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/brk_based.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d7369766076aec1dad859894574c0db23c85e29 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c239f14078cb1d2da6df29f582713a3c14d788e Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_absolute.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda4d2e8d01c87252607b4b5da3424fdb7ff31f0 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..269a35bd9bb6605006c7431ea0d320ed293302d0 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/naive_relative.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82512a33f500fe06e096da4cc221deeac69f42c8 Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..741e084997009584513b4b9673da439210e63bce Binary files /dev/null and b/tania_scripts/supar/codelin/encs/enc_deps/__pycache__/pos_based.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py b/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py new file mode 100644 index 0000000000000000000000000000000000000000..665e9e6c61fd6a729d023349d203f87ae3765546 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/brk2p_based.py @@ -0,0 +1,215 @@ +from supar.codelin.encs.abstract_encoding import ADEncoding +from supar.codelin.utils.constants import D_2P_GREED, D_2P_PROP, D_NONE_LABEL +from supar.codelin.models.deps_label import D_Label +from supar.codelin.models.deps_tree import D_Tree +from supar.codelin.models.linearized_tree import LinearizedTree + + +class D_Brk2PBasedEncoding(ADEncoding): + def __init__(self, separator, displacement, planar_alg): + if planar_alg and planar_alg not in [D_2P_GREED, D_2P_PROP]: + print("[*] Error: Unknown planar separation algorithm") + exit(1) + super().__init__(separator) + self.displacement = displacement + self.planar_alg = planar_alg + + def __str__(self): + return "Dependency 2-Planar Bracketing Based Encoding" + + def get_next_edge(self, dep_tree, idx_l, idx_r): + next_arc=None + + if dep_tree[idx_l].head==idx_r: + next_arc = dep_tree[idx_l] + + elif dep_tree[idx_r].head==idx_l: + next_arc = dep_tree[idx_r] + + return next_arc + + def two_planar_propagate(self, nodes): + p1=[] + p2=[] + fp1=[] + fp2=[] + + for i in range(0, (len(nodes))): + for j in range(i, -1, -1): + # if the node in position 'i' has an arc to 'j' + # or node in position 'j' has an arc to 'i' + next_arc=self.get_next_edge(nodes, i, j) + if next_arc is None: + continue + else: + # check restrictions + if next_arc not in fp1: + p1.append(next_arc) + fp1, fp2 = self.propagate(nodes, fp1, fp2, next_arc, 2) + + elif next_arc not in fp2: + p2.append(next_arc) + fp1, fp2 = self.propagate(nodes, fp1, fp2, next_arc, 1) + return p1, p2 + def propagate(self, nodes, fp1, fp2, current_edge, i): + # add the current edge to the forbidden plane opposite to the plane + # where the node has already been added + fpi = None + fp3mi= None + if i==1: + fpi = fp1 + fp3mi= fp2 + if i==2: + fpi = fp2 + fp3mi= fp1 + + fpi.append(current_edge) + + # add all nodes from the dependency graph that crosses the current edge + # to the corresponding forbidden plane + for node in nodes: + if current_edge.check_cross(node): + if node not in fp3mi: + (fp1, fp2)=self.propagate(nodes, fp1, fp2, node, 3-i) + + return fp1, fp2 + + def two_planar_greedy(self, dep_tree): + plane_1 = [] + plane_2 = [] + + for i in range(len(dep_tree)): + for j in range(i, -1, -1): + # if the node in position 'i' has an arc to 'j' + # or node in position 'j' has an arc to 'i' + next_arc = self.get_next_edge(dep_tree, i, j) + if next_arc is None: + continue + + else: + cross_plane_1 = False + cross_plane_2 = False + for node in plane_1: + cross_plane_1 = cross_plane_1 or next_arc.check_cross(node) + for node in plane_2: + cross_plane_2 = cross_plane_2 or next_arc.check_cross(node) + + if not cross_plane_1: + plane_1.append(next_arc) + elif not cross_plane_2: + plane_2.append(next_arc) + + # processs them separately + return plane_1,plane_2 + + + def encode(self, dep_tree): + # create brackets array + n_nodes = len(dep_tree) + labels_brk = [""] * (n_nodes + 1) + + # separate the planes + if self.planar_alg==D_2P_GREED: + p1_nodes, p2_nodes = self.two_planar_greedy(dep_tree) + elif self.planar_alg==D_2P_PROP: + p1_nodes, p2_nodes = self.two_planar_propagate(dep_tree) + + # get brackets separatelly + labels_brk = self.encode_step(p1_nodes, labels_brk, ['>','/','\\','<']) + labels_brk = self.encode_step(p2_nodes, labels_brk, ['>*','/*','\\*','<*']) + + # merge and obtain labels + lbls=[] + dep_tree.remove_dummy() + for node in dep_tree: + current = D_Label(labels_brk[node.id], node.relation, self.separator) + lbls.append(current) + return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), lbls, len(lbls)) + + def encode_step(self, p, lbl_brk, brk_chars): + for node in p: + # skip root relations (optional?) + if node.head==0: + continue + if node.id < node.head: + if self.displacement: + lbl_brk[node.id+1]+=brk_chars[3] + else: + lbl_brk[node.id]+=brk_chars[3] + + lbl_brk[node.head]+=brk_chars[2] + else: + if self.displacement: + lbl_brk[node.head+1]+=brk_chars[1] + else: + lbl_brk[node.head]+=brk_chars[1] + + lbl_brk[node.id]+=brk_chars[0] + return lbl_brk + + def decode(self, lin_tree): + decoded_tree = D_Tree.empty_tree(len(lin_tree)+1) + + # create plane stacks + l_stack_p1=[] + l_stack_p2=[] + r_stack_p1=[] + r_stack_p2=[] + + current_node=1 + + for word, postag, features, label in lin_tree.iterrows(): + brks = list(label.xi) if label.xi != D_NONE_LABEL else [] + temp_brks=[] + + for i in range(0, len(brks)): + current_char=brks[i] + if brks[i]=="*": + current_char=temp_brks.pop()+brks[i] + temp_brks.append(current_char) + + brks=temp_brks + + # set parameters to the node + decoded_tree.update_word(current_node, word) + decoded_tree.update_upos(current_node, postag) + decoded_tree.update_relation(current_node, label.li) + + # fill the relation using brks + for char in brks: + if char == "<": + node_id=current_node + (-1 if self.displacement else 0) + r_stack_p1.append((node_id,char)) + + if char == "\\": + head_id = r_stack_p1.pop()[0] if len(r_stack_p1)>0 else 0 + decoded_tree.update_head(head_id, current_node) + + if char =="/": + node_id=current_node + (-1 if self.displacement else 0) + l_stack_p1.append((node_id,char)) + + if char == ">": + head_id = l_stack_p1.pop()[0] if len(l_stack_p1)>0 else 0 + decoded_tree.update_head(current_node, head_id) + + if char == "<*": + node_id=current_node + (-1 if self.displacement else 0) + r_stack_p2.append((node_id,char)) + + if char == "\\*": + head_id = r_stack_p2.pop()[0] if len(r_stack_p2)>0 else 0 + decoded_tree.update_head(head_id, current_node) + + if char =="/*": + node_id=current_node + (-1 if self.displacement else 0) + l_stack_p2.append((node_id,char)) + + if char == ">*": + head_id = l_stack_p2.pop()[0] if len(l_stack_p2)>0 else 0 + decoded_tree.update_head(current_node, head_id) + + current_node+=1 + + decoded_tree.remove_dummy() + return decoded_tree diff --git a/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py b/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py new file mode 100644 index 0000000000000000000000000000000000000000..58758155d91e6fe7b43544f6d8fb0eb262879e65 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/brk_based.py @@ -0,0 +1,87 @@ +from supar.codelin.encs.abstract_encoding import ADEncoding +from supar.codelin.models.deps_label import D_Label +from supar.codelin.models.deps_tree import D_Tree +from supar.codelin.utils.constants import D_NONE_LABEL +from supar.codelin.models.linearized_tree import LinearizedTree + +class D_BrkBasedEncoding(ADEncoding): + + def __init__(self, separator, displacement): + super().__init__(separator) + self.displacement = displacement + + def __str__(self): + return "Dependency Bracketing Based Encoding" + + + def encode(self, dep_tree): + n_nodes = len(dep_tree) + labels_brk = [""] * (n_nodes + 1) + encoded_labels = [] + + # compute brackets array + # brackets array should be sorted ? + dep_tree.remove_dummy() + for node in dep_tree: + # skip root relations (optional?) + if node.head == 0: + continue + + if node.is_left_arc(): + labels_brk[node.id + (1 if self.displacement else 0)]+='<' + labels_brk[node.head]+='\\' + + else: + labels_brk[node.head + (1 if self.displacement else 0)]+='/' + labels_brk[node.id]+='>' + + # encode labels + for node in dep_tree: + li = node.relation + xi = labels_brk[node.id] + + current = D_Label(xi, li, self.separator) + encoded_labels.append(current) + + return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels)) + + def decode(self, lin_tree): + # Create an empty tree with n labels + decoded_tree = D_Tree.empty_tree(len(lin_tree)+1) + + l_stack = [] + r_stack = [] + + current_node = 1 + for word, postag, features, label in lin_tree.iterrows(): + + # get the brackets + brks = list(label.xi) if label.xi != D_NONE_LABEL else [] + + # set parameters to the node + decoded_tree.update_word(current_node, word) + decoded_tree.update_upos(current_node, postag) + decoded_tree.update_relation(current_node, label.li) + + # fill the relation using brks + for char in brks: + if char == "<": + node_id = current_node + (-1 if self.displacement else 0) + r_stack.append(node_id) + + if char == "\\": + head_id = r_stack.pop() if len(r_stack) > 0 else 0 + decoded_tree.update_head(head_id, current_node) + + if char =="/": + node_id = current_node + (-1 if self.displacement else 0) + l_stack.append(node_id) + + if char == ">": + head_id = l_stack.pop() if len(l_stack) > 0 else 0 + decoded_tree.update_head(current_node, head_id) + + current_node+=1 + + decoded_tree.remove_dummy() + return decoded_tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py b/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py new file mode 100644 index 0000000000000000000000000000000000000000..e624bb34feeddfa2cb8c149bfbf4f45f3b813abf --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/naive_absolute.py @@ -0,0 +1,42 @@ +from supar.codelin.encs.abstract_encoding import ADEncoding +from supar.codelin.models.deps_label import D_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.deps_tree import D_Tree +from supar.codelin.utils.constants import D_NONE_LABEL + +class D_NaiveAbsoluteEncoding(ADEncoding): + def __init__(self, separator): + super().__init__(separator) + + def __str__(self): + return "Dependency Naive Absolute Encoding" + + def encode(self, dep_tree): + encoded_labels = [] + dep_tree.remove_dummy() + + for node in dep_tree: + li = node.relation + xi = node.head + + current = D_Label(xi, li, self.separator) + encoded_labels.append(current) + + return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels)) + + def decode(self, lin_tree): + dep_tree = D_Tree.empty_tree(len(lin_tree)+1) + + i=1 + for word, postag, features, label in lin_tree.iterrows(): + if label.xi == D_NONE_LABEL: + label.xi = 0 + + dep_tree.update_word(i, word) + dep_tree.update_upos(i, postag) + dep_tree.update_relation(i, label.li) + dep_tree.update_head(i, int(label.xi)) + i+=1 + + dep_tree.remove_dummy() + return dep_tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py b/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py new file mode 100644 index 0000000000000000000000000000000000000000..6880f271246ae0d1ef2ce07132b1ea255708b60c --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/naive_relative.py @@ -0,0 +1,48 @@ +from supar.codelin.encs.abstract_encoding import ADEncoding +from supar.codelin.models.deps_label import D_Label +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.models.deps_tree import D_Tree +from supar.codelin.utils.constants import D_NONE_LABEL + +class D_NaiveRelativeEncoding(ADEncoding): + def __init__(self, separator, hang_from_root): + super().__init__(separator) + self.hfr = hang_from_root + + def __str__(self): + return "Dependency Naive Relative Encoding" + + def encode(self, dep_tree): + encoded_labels = [] + dep_tree.remove_dummy() + for node in dep_tree: + li = node.relation + xi = node.delta_head() + + if node.relation == 'root' and self.hfr: + xi = D_NONE_LABEL + + current = D_Label(xi, li, self.separator) + encoded_labels.append(current) + + return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels)) + + def decode(self, lin_tree): + dep_tree = D_Tree.empty_tree(len(lin_tree)+1) + + i = 1 + for word, postag, features, label in lin_tree.iterrows(): + if label.xi == D_NONE_LABEL: + # set as root + dep_tree.update_head(i, 0) + else: + dep_tree.update_head(i, int(label.xi)+(i)) + + + dep_tree.update_word(i, word) + dep_tree.update_upos(i, postag) + dep_tree.update_relation(i, label.li) + i+=1 + + dep_tree.remove_dummy() + return dep_tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py b/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py new file mode 100644 index 0000000000000000000000000000000000000000..714efd8d1aff41234ef06409474d491ff96d23f9 --- /dev/null +++ b/tania_scripts/supar/codelin/encs/enc_deps/pos_based.py @@ -0,0 +1,89 @@ +from supar.codelin.encs.abstract_encoding import ADEncoding +from supar.codelin.models.deps_label import D_Label +from supar.codelin.models.deps_tree import D_Tree +from supar.codelin.models.linearized_tree import LinearizedTree +from supar.codelin.utils.constants import D_POSROOT, D_NONE_LABEL + +POS_ROOT_LABEL = "0--ROOT" + +class D_PosBasedEncoding(ADEncoding): + def __init__(self, separator): + super().__init__(separator) + + def __str__(self) -> str: + return "Dependency Part-of-Speech Based Encoding" + + def encode(self, dep_tree): + + print("upar/codelin/encs/enc_deps/pos_based.py") + encoded_labels = [] + + for node in dep_tree: + if node.id == 0: + # skip dummy root + continue + + li = node.relation + pi = dep_tree[node.head].upos + oi = 0 + + # move left or right depending if the node + # dependency edge is to the left or to the right + + step = 1 if node.id < node.head else -1 + for i in range(node.id + step, node.head + step, step): + if pi == dep_tree[i].upos: + oi += step + + xi = str(oi)+"--"+pi + + current = D_Label(xi, li, self.separator) + encoded_labels.append(current) + + dep_tree.remove_dummy() + return LinearizedTree(dep_tree.get_words(), dep_tree.get_postags(), dep_tree.get_feats(), encoded_labels, len(encoded_labels)) + + def decode(self, lin_tree): + dep_tree = D_Tree.empty_tree(len(lin_tree)+1) + + i = 1 + postags = lin_tree.postags + for word, postag, features, label in lin_tree.iterrows(): + node_id = i + if label.xi == D_NONE_LABEL: + label.xi = POS_ROOT_LABEL + + dep_tree.update_word(node_id, word) + dep_tree.update_upos(node_id, postag) + dep_tree.update_relation(node_id, label.li) + + oi, pi = label.xi.split('--') + oi = int(oi) + + # Set head for root + if (pi==D_POSROOT or oi==0): + dep_tree.update_head(node_id, 0) + i+=1 + continue + + # Compute head position + target_oi = oi + + step = 1 if oi > 0 else -1 + stop_point = (len(postags)+1) if oi > 0 else 0 + + for j in range(node_id+step, stop_point, step): + if (pi == postags[j-1]): + target_oi -= step + + if (target_oi==0): + break + + head_id = j + dep_tree.update_head(node_id, head_id) + + i+=1 + + + dep_tree.remove_dummy() + return dep_tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/models/__init__.py b/tania_scripts/supar/codelin/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..055dc41308b5847ea209b4ae31a538031037b6ad Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7f007831d62522735be047094f4cf4fdcc4000c Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab31bdd2279ddff308fa26008115f5475640b1c Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6816f095c20760ddde9f1421313ff80f2f8e887 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_label.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..069e2ef3ba4d3d9cbe4ae8d6dea1e8105dfb3d47 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..376f81b9df97423cdc912c3eac7e7063b9e838a0 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/const_tree.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a316c2c1296885e31867d3fd9fabd908a3cc177 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9643deb605dccdecaca4931fdcf8ec1c1ad66a6b Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_label.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8e4df612a26f2cf508b2bb5fa7e85648a86b36d Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3679ea018cddc145eefe45675c608e6475204d Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/deps_tree.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b35816c4e91140d23f0a05ada4956a3eca4872 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30e902912d8da97efe2c282ebdf01db36bf1e563 Binary files /dev/null and b/tania_scripts/supar/codelin/models/__pycache__/linearized_tree.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/models/const_label.py b/tania_scripts/supar/codelin/models/const_label.py new file mode 100644 index 0000000000000000000000000000000000000000..1df73ca2919dc76c935c1e1641554be3869e8ae1 --- /dev/null +++ b/tania_scripts/supar/codelin/models/const_label.py @@ -0,0 +1,38 @@ +from supar.codelin.utils.constants import C_ABSOLUTE_ENCODING, C_RELATIVE_ENCODING, C_NONE_LABEL + +class C_Label: + def __init__(self, nc, lc, uc, et, sp, uj): + self.encoding_type = et + + self.n_commons = int(nc) + self.last_common = lc + self.unary_chain = uc if uc != C_NONE_LABEL else None + self.separator = sp + self.unary_joiner = uj + + def __repr__(self): + unary_str = self.unary_joiner.join([self.unary_chain]) if self.unary_chain else "" + return (str(self.n_commons) + ("*" if self.encoding_type==C_RELATIVE_ENCODING else "") + + self.separator + self.last_common + (self.separator + unary_str if self.unary_chain else "")) + + def to_absolute(self, last_label): + self.n_commons+=last_label.n_commons + if self.n_commons<=0: + self.n_commons = 1 + + self.encoding_type=C_ABSOLUTE_ENCODING + + @staticmethod + def from_string(l, sep, uj): + label_components = l.split(sep) + + if len(label_components)== 2: + nc, lc = label_components + uc = None + else: + nc, lc, uc = label_components + + et = C_RELATIVE_ENCODING if '*' in nc else C_ABSOLUTE_ENCODING + nc = nc.replace("*","") + return C_Label(nc, lc, uc, et, sep, uj) + \ No newline at end of file diff --git a/tania_scripts/supar/codelin/models/const_tree.py b/tania_scripts/supar/codelin/models/const_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..e5466582822ecf47bdf76131495191c49da94946 --- /dev/null +++ b/tania_scripts/supar/codelin/models/const_tree.py @@ -0,0 +1,414 @@ +from supar.codelin.utils.constants import C_END_LABEL, C_START_LABEL, C_NONE_LABEL +from supar.codelin.utils.constants import C_CONFLICT_SEPARATOR, C_STRAT_MAX, C_STRAT_FIRST, C_STRAT_LAST, C_NONE_LABEL, C_ROOT_LABEL +import copy + +class C_Tree: + def __init__(self, label, children=[], feats=None): + self.parent = None + self.label = label + self.children = children + self.features = feats + +# Adders and deleters + def add_child(self, child): + ''' + Function that adds a child to the current tree + ''' + if type(child) is list: + for c in child: + self.add_child(c) + elif type(child) is C_Tree: + self.children.append(child) + child.parent = self + + else: + raise TypeError("[!] Child must be a ConstituentTree or a list of Constituent Trees") + + def add_left_child(self, child): + ''' + Function that adds a child to the left of the current tree + ''' + if type(child) is not C_Tree: + raise TypeError("[!] Child must be a ConstituentTree") + + self.children = [child] + self.children + child.parent = self + + def del_child(self, child): + ''' + Function that deletes a child from the current tree + without adding its children to the current tree + ''' + if type(child) is not C_Tree: + raise TypeError("[!] Child must be a ConstituentTree") + + self.children.remove(child) + child.parent = None + +# Getters + def r_child(self): + ''' + Function that returns the rightmost child of a tree + ''' + return self.children[len(self.children)-1] + + def l_child(self): + ''' + Function that returns the leftmost child of a tree + ''' + return self.children[0] + + def r_siblings(self): + ''' + Function that returns the right siblings of a tree + ''' + return self.parent.children[self.parent.children.index(self)+1:] + + def l_siblings(self): + ''' + Function that returns the left siblings of a tree + ''' + return self.parent.children[:self.parent.children.index(self)] + + def get_root(self): + ''' + Function that returns the root of a tree + ''' + if self.parent is None: + return self + else: + return self.parent.get_root() + + def extract_features(self, f_mark = "##", f_sep = "|"): + # go through all pre-terminal nodes + # of the tree + for node in self.get_preterminals(): + + if f_mark in node.label: + node.features = {} + label = node.label.split(f_mark)[0] + features = node.label.split(f_mark)[1] + + node.label = label + + # add features to the tree + for feature in features.split(f_sep): + + if feature == "_": + continue + + key = feature.split("=")[0] + value = feature.split("=")[1] + + node.features[key]=value + + def get_feature_names(self): + ''' + Returns a set containing all feature names + for the tree + ''' + feat_names = set() + + for child in self.children: + feat_names = feat_names.union(child.get_feature_names()) + if self.features is not None: + feat_names = feat_names.union(set(self.features.keys())) + + return feat_names + +# Word and Postags getters + def get_words(self): + ''' + Function that returns the terminal nodes of a tree + ''' + if self.is_terminal(): + return [self.label] + else: + return [node for child in self.children for node in child.get_words()] + + def get_postags(self): + ''' + Function that returns the preterminal nodes of a tree + ''' + if self.is_preterminal(): + return [self.label] + else: + return [node for child in self.children for node in child.get_postags()] + +# Terminal checking + def is_terminal(self): + ''' + Function that checks if a tree is a terminal + ''' + return len(self.children) == 0 + + def is_preterminal(self): + ''' + Function that checks if a tree is a preterminal + ''' + return len(self.children) == 1 and self.children[0].is_terminal() + +# Terminal getters + def get_terminals(self): + ''' + Function that returns the terminal nodes of a tree + ''' + if self.is_terminal(): + return [self] + else: + return [node for child in self.children for node in child.get_terminals()] + + def get_preterminals(self): + ''' + Function that returns the terminal nodes of a tree + ''' + if self.is_preterminal(): + return [self] + else: + return [node for child in self.children for node in child.get_preterminals()] + +# Tree processing + def collapse_unary(self, unary_joiner="+"): + ''' + Function that collapses unary chains + into single nodes using a unary_joiner as join character + ''' + for child in self.children: + child.collapse_unary(unary_joiner) + if len(self.children)==1 and not self.is_preterminal(): + self.label += unary_joiner + self.children[0].label + self.children = self.children[0].children + + def inherit_tree(self): + ''' + Removes the top node of the tree and delegates it + to its firstborn child. + + (S (NP (NNP John)) (VP (VBD died))) => (NP (NNP John)) + ''' + self.label = self.children[0].label + self.children = self.children[0].children + + def add_end_node(self): + ''' + Function that adds a dummy end node to the + rightmost part of the tree + ''' + self.add_child(C_Tree(C_END_LABEL, [])) + + def add_start_node(self): + ''' + Function that adds a dummy start node to the leftmost + part of the tree + ''' + self.add_left_child(C_Tree(C_START_LABEL, [])) + + def path_to_leaves(self, collapse_unary=True, unary_joiner="+"): + ''' + Function that given a Tree returns a list of paths + from the root to the leaves, encoding a level index into + nodes to make them unique. + ''' + self.add_end_node() + + if collapse_unary: + self.collapse_unary(unary_joiner) + + paths = self.path_to_leaves_rec([],[],0) + return paths + + def path_to_leaves_rec(self, current_path, paths, idx): + ''' + Recursive step of the path_to_leaves function where we store + the common path based on the current node + ''' + # pass by value + path = copy.deepcopy(current_path) + + if (len(self.children)==0): + # we are at a leaf. store the path in a new list + path.append(self.label) + paths.append(path) + else: + path.append(self.label+str(idx)) + for child in self.children: + child.path_to_leaves_rec(path, paths, idx) + idx+=1 + return paths + + def fill_pos_nodes(self, postag, word, unary_chain, unary_joiner): + if self.label == postag: + # if the current level is already a postag level. This may happen on + # trees shaped as (NP tree) that exist on the SPMRL treebanks + self.children.append(C_Tree(word, [])) + return + + if unary_chain: + unary_chain = unary_chain.split(unary_joiner) + unary_chain.reverse() + pos_tree = C_Tree(postag, [C_Tree(word, [])]) + for node in unary_chain: + temp_tree = C_Tree(node, [pos_tree]) + pos_tree = temp_tree + else: + pos_tree = C_Tree(postag, [C_Tree(word, [])]) + + self.add_child(pos_tree) + + def renounce_children(self): + ''' + Function that deletes current tree from its parent + and adds its children to the parent + ''' + self.parent.children = self.l_siblings() + self.children + self.r_siblings() + for child in self.children: + child.parent = self.parent + + + def prune_nones(self): + """ + Return a copy of the tree without + null nodes (nodes with label C_NONE_LABEL) + """ + if self.label != C_NONE_LABEL: + t = C_Tree(self.label, []) + new_childs = [c.prune_nones() for c in self.children] + t.add_child(new_childs) + return t + + else: + return [c.prune_nones() for c in self.children] + + def remove_conflicts(self, conflict_strat): + ''' + Removes all conflicts in the label of the tree generated + during the decoding process. Conflicts will be signaled by -||- + string. + ''' + for c in self.children: + if type(c) is C_Tree: + c.remove_conflicts(conflict_strat) + if C_CONFLICT_SEPARATOR in self.label: + labels = self.label.split(C_CONFLICT_SEPARATOR) + + if conflict_strat == C_STRAT_MAX: + self.label = max(set(labels), key=labels.count) + if conflict_strat == C_STRAT_FIRST: + self.label = labels[0] + if conflict_strat == C_STRAT_LAST: + self.label = labels[len(labels)-1] + + def postprocess_tree(self, conflict_strat, clean_nulls=True, default_root="S"): + ''' + Returns a C_Tree object with conflicts in node labels removed + and with NULL nodes cleaned. + ''' + if clean_nulls: + if self.label == C_NONE_LABEL or self.label==C_ROOT_LABEL: + self.label = default_root + t = self.prune_nones() + else: + t = self + t.remove_conflicts(conflict_strat) + return t + + # print( fix_tree) + + def reverse_tree(self): + ''' + Reverses the order of all the tree children + ''' + for c in self.children: + if type(c) is C_Tree: + c.reverse_tree() + self.children.reverse() + +# Printing and python-related functions + def __str__(self): + if len(self.children) == 0: + label_str = self.label + + if self.features is not None: + features_str = "##" + "|".join([key+"="+value for key,value in self.features.items()]) + + label_str = label_str.replace("(","-LRB-") + label_str = label_str.replace(")","-RRB-") + else: + label_str = "(" + self.label + " " + if self.features is not None: + features_str = "##"+ "|".join([key+"="+value for key,value in self.features.items()]) + + label_str += " ".join([str(child) for child in self.children]) + ")" + return label_str + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if isinstance(other, C_Tree): + return self.label == other.label and self.children == other.children + return False + + def __hash__(self): + return hash((self.label, tuple(self.children))) + + def __len__(self): + return len(self.children) + + def __iter__(self): + yield self.label + for child in self.children: + yield child + + def __contains__(self, item): + return item in self.label or item in self.children + + +# Tree creation + @staticmethod + def from_string(s): + s = s.replace("(","( ") + s = s.replace(")"," )") + s = s.split(" ") + + # create dummy label and append it to the stack + stack = [] + i=0 + while i < (len(s)): + if s[i]=="(": + # If we find a l_brk we create a new tree + # with label=next_word. Skip next_word. + w = s[i+1] + t = C_Tree(w, []) + stack.append(t) + i+=1 + + elif s[i]==")": + # If we find a r_brk set top of the stack + # as children to the second top of the stack + + t = stack.pop() + + if len(stack)==0: + return t + + pt = stack.pop() + pt.add_child(t) + stack.append(pt) + + else: + # If we find a word set it as children + # of the current tree. + t = stack.pop() + w = s[i] + c = C_Tree(w, []) + t.add_child(c) + stack.append(t) + + i+=1 + return t + +# Default trees + @staticmethod + def empty_tree(): + return C_Tree(C_NONE_LABEL, []) \ No newline at end of file diff --git a/tania_scripts/supar/codelin/models/deps_label.py b/tania_scripts/supar/codelin/models/deps_label.py new file mode 100644 index 0000000000000000000000000000000000000000..ada0b83c036cabb6bff8ebf715d1c068cc2f2c82 --- /dev/null +++ b/tania_scripts/supar/codelin/models/deps_label.py @@ -0,0 +1,18 @@ +class D_Label: + def __init__(self, xi, li, sp): + self.separator = sp + + self.xi = xi # dependency relation + self.li = li # encoding + + def __repr__(self): + return f'{self.xi}{self.separator}{self.li}' + + @staticmethod + def from_string(lbl_str, sep): + xi, li = lbl_str.split(sep) + return D_Label(xi, li, sep) + + @staticmethod + def empty_label(separator): + return D_Label("", "", separator) diff --git a/tania_scripts/supar/codelin/models/deps_tree.py b/tania_scripts/supar/codelin/models/deps_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..43130267b3ea24f27855f420ae2ce4abbbf08f94 --- /dev/null +++ b/tania_scripts/supar/codelin/models/deps_tree.py @@ -0,0 +1,371 @@ +from supar.codelin.utils.constants import D_ROOT_HEAD, D_NULLHEAD, D_ROOT_REL, D_POSROOT, D_EMPTYREL + +class D_Node: + def __init__(self, wid, form, lemma=None, upos=None, xpos=None, feats=None, head=None, deprel=None, deps=None, misc=None): + self.id = int(wid) # word id + + self.form = form if form else "_" # word + self.lemma = lemma if lemma else "_" # word lemma/stem + self.upos = upos if upos else "_" # universal postag + self.xpos = xpos if xpos else "_" # language_specific postag + self.feats = self.parse_feats(feats) if feats else "_" # morphological features + + self.head = int(head) # id of the word that depends on + self.relation = deprel # type of relation with head + + self.deps = deps if deps else "_" # enhanced dependency graph + self.misc = misc if misc else "_" # miscelaneous data + + def is_left_arc(self): + return self.head > self.id + + def delta_head(self): + return self.head - self.id + + def parse_feats(self, feats): + if feats == '_': + return [None] + else: + return [x for x in feats.split('|')] + + def check_cross(self, other): + if ((self.head == other.head) or (self.head==other.id)): + return False + + r_id_inside = (other.head < self.id < other.id) + l_id_inside = (other.id < self.id < other.head) + + id_inside = r_id_inside or l_id_inside + + r_head_inside = (other.head < self.head < other.id) + l_head_inside = (other.id < self.head < other.head) + + head_inside = r_head_inside or l_head_inside + + return head_inside^id_inside + + def __repr__(self): + return '\t'.join(str(e) for e in list(self.__dict__.values()))+'\n' + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + + @staticmethod + def from_string(conll_str): + wid,form,lemma,upos,xpos,feats,head,deprel,deps,misc = conll_str.split('\t') + return D_Node(int(wid), form, lemma, upos, xpos, feats, int(head), deprel, deps, misc) + + @staticmethod + def dummy_root(): + return D_Node(0, D_POSROOT, None, D_POSROOT, None, None, 0, D_EMPTYREL, None, None) + + @staticmethod + def empty_node(): + return D_Node(0, None, None, None, None, None, 0, None, None, None) + +class D_Tree: + def __init__(self, nodes): + self.nodes = nodes + +# getters + def get_node(self, id): + return self.nodes[id-1] + + def get_edges(self): + ''' + Return sentence dependency edges as a tuple + shaped as ((d,h),r) where d is the dependant of the relation, + h the head of the relation and r the relationship type + ''' + return list(map((lambda x :((x.id, x.head), x.relation)), self.nodes)) + + def get_arcs(self): + ''' + Return sentence dependency edges as a tuple + shaped as (d,h) where d is the dependant of the relation, + and h the head of the relation. + ''' + return list(map((lambda x :(x.id, x.head)), self.nodes)) + + def get_relations(self): + ''' + Return a list of relationships betwee nodes + ''' + return list(map((lambda x :x.relation), self.nodes)) + + def get_sentence(self): + ''' + Return the sentence as a string + ''' + return " ".join(list(map((lambda x :x.form), self.nodes))) + + def get_words(self): + ''' + Returns the words of the sentence as a list + ''' + return list(map((lambda x :x.form), self.nodes)) + + def get_indexes(self): + ''' + Returns a list of integers representing the words of the + dependency tree + ''' + return list(map((lambda x :x.id), self.nodes)) + + def get_postags(self): + ''' + Returns the part of speech tags of the tree + ''' + return list(map((lambda x :x.upos), self.nodes)) + + def get_lemmas(self): + ''' + Returns the lemmas of the tree + ''' + return list(map((lambda x :x.lemma), self.nodes)) + + def get_heads(self): + ''' + Returns the heads of the tree + ''' + return list(map((lambda x :x.head), self.nodes)) + + def get_feats(self): + ''' + Returns the morphological features of the tree + ''' + return list(map((lambda x :x.feats), self.nodes)) + +# update functions + def append_node(self, node): + ''' + Append a node to the tree and sorts the nodes by id + ''' + self.nodes.append(node) + self.nodes.sort(key=lambda x: x.id) + + def update_head(self, node_id, head_value): + ''' + Update the head of a node indicated by its id + ''' + for node in self.nodes: + if node.id == node_id: + node.head = head_value + break + + def update_relation(self, node_id, relation_value): + ''' + Update the relation of a node indicated by its id + ''' + for node in self.nodes: + if node.id == node_id: + node.relation = relation_value + break + + def update_word(self, node_id, word): + ''' + Update the word of a node indicated by its id + ''' + for node in self.nodes: + if node.id == node_id: + node.form = word + break + + def update_upos(self, node_id, postag): + ''' + Update the upos field of a node indicated by its id + ''' + for node in self.nodes: + if node.id == node_id: + node.upos = postag + break + +# properties functions + def is_projective(self): + ''' + Returns a boolean indicating if the dependency tree + is projective (i.e. no edges are crossing) + ''' + arcs = self.get_arcs() + for (i,j) in arcs: + for (k,l) in arcs: + if (i,j) != (k,l) and min(i,j) < min(k,l) < max(i,j) < max(k,l): + return False + return True + +# postprocessing + def remove_dummy(self): + self.nodes = self.nodes[1:] + + def postprocess_tree(self, search_root_strat, allow_multi_roots=False): + ''' + Postprocess the tree by finding the root according to the selected + strategy and fixing cycles and out of bounds heads + ''' + # 1) Find the root + root = self.root_search(search_root_strat) + + # 2) Fix oob heads + self.fix_oob_heads() + + # 3) Fix cycles + self.fix_cycles(root) + + # 4) Set all null heads to root and remove other root candidates + for node in self.nodes: + if node.id == root: + node.head = 0 + continue + if node.head == D_NULLHEAD: + node.head = root + if not allow_multi_roots and node.head == 0: + node.head = root + + def root_search(self, search_root_strat): + ''' + Search for the root of the tree using the method indicated + ''' + root = 1 # Default root + for node in self.nodes: + if search_root_strat == D_ROOT_HEAD: + if node.head == 0: + root = node.id + break + + elif search_root_strat == D_ROOT_REL: + if node.rel == 'root' or node.rel == 'ROOT': + root = node.id + break + return root + + def fix_oob_heads(self): + ''' + Fixes heads of the tree (if they dont exist, if they are out of bounds, etc) + If a head is out of bounds set it to nullhead + ''' + for node in self.nodes: + if node.head==D_NULLHEAD: + continue + if int(node.head) < 0: + node.head = D_NULLHEAD + elif int(node.head) > len(self.nodes): + node.head = D_NULLHEAD + + def fix_cycles(self, root): + ''' + Breaks cycles in the tree by setting the head of the node to root_id + ''' + for node in self.nodes: + visited = [] + + while (node.id != root) and (node.head !=D_NULLHEAD): + if node in visited: + node.head = D_NULLHEAD + else: + visited.append(node) + next_node = min(max(node.head-1, 0), len(self.nodes)-1) + node = self.nodes[next_node] + +# python related functions + def __repr__(self): + return "".join(str(e) for e in self.nodes)+"\n" + + def __iter__(self): + for n in self.nodes: + yield n + + def __getitem__(self, key): + return self.nodes[key] + + def __len__(self): + return self.nodes.__len__() + +# base tree + @staticmethod + def empty_tree(l=1): + ''' + Creates an empty dependency tree with l nodes + ''' + t = D_Tree([]) + for i in range(l): + n = D_Node.empty_node() + n.id = i + t.append_node(n) + return t + +# reader and writter + @staticmethod + def from_string(conll_str, dummy_root=True, clean_contractions=True, clean_omisions=True): + ''' + Create a ConllTree from a dependency tree conll-u string. + ''' + data = conll_str.split('\n') + dependency_tree_start_index = 0 + for line in data: + if len(line)>0 and line[0]!="#": + break + dependency_tree_start_index+=1 + data = data[dependency_tree_start_index:] + nodes = [] + if dummy_root: + nodes.append(D_Node.dummy_root()) + + for line in data: + # check if not valid line (empty or not enough fields) + if (len(line)<=1) or len(line.split('\t'))<10: + continue + + wid = line.split('\t')[0] + + # check if node is a comment (comments are marked with #) + if "#" in wid: + continue + + # check if node is a contraction (multiexp lines are marked with .) + if clean_contractions and "-" in wid: + continue + + # check if node is an omited word (empty nodes are marked with .) + if clean_omisions and "." in wid: + continue + + conll_node = D_Node.from_string(line) + nodes.append(conll_node) + + return D_Tree(nodes) + + @staticmethod + def read_conllu_file(file_path, filter_projective = True): + ''' + Read a conllu file and return a list of ConllTree objects. + ''' + with open(file_path, 'r') as f: + data = f.read() + data = data.split('\n\n') + # remove last empty line + data = data[:-1] + + trees = [] + for x in data: + t = D_Tree.from_string(x) + if not filter_projective or t.is_projective(): + trees.append(t) + return trees + + @staticmethod + def write_conllu_file(file_path, trees): + ''' + Write a list of ConllTree objects to a conllu file. + ''' + with open(file_path, 'w') as f: + f.write("".join(str(e) for e in trees)) + + @staticmethod + def write_conllu(file_io, tree): + ''' + Write a single ConllTree to a already open file. + Includes the # text = ... line + ''' + file_io.write("# text = "+tree.get_sentence()+"\n") + file_io.write("".join(str(e) for e in tree)+"\n") \ No newline at end of file diff --git a/tania_scripts/supar/codelin/models/linearized_tree.py b/tania_scripts/supar/codelin/models/linearized_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..1619b315c986e2c682f1e8d474f053653dd86558 --- /dev/null +++ b/tania_scripts/supar/codelin/models/linearized_tree.py @@ -0,0 +1,179 @@ +from supar.codelin.utils.constants import BOS, EOS, C_NO_POSTAG_LABEL +from supar.codelin.models.const_label import C_Label +from supar.codelin.models.deps_label import D_Label + +class LinearizedTree: + def __init__(self, words, postags, additional_feats, labels, n_feats): + self.words = words + self.postags = postags + self.additional_feats = additional_feats + self.labels = labels + #len(f_idx_dict.keys()) = n_feats + + def get_sentence(self): + return "".join(self.words) + + def get_labels(self): + return self.labels + + def get_word(self, index): + return self.words[index] + + def get_postag(self, index): + return self.postags[index] + + def get_additional_feat(self, index): + return self.additional_feats[index] if len(self.additional_feats) > 0 else None + + def get_label(self, index): + return self.labels[index] + + def reverse_tree(self, ignore_bos_eos=True): + ''' + Reverses the lists of words, postags, additional_feats and labels. + Do not reverses the first (BOS) and last (EOS) elements + ''' + if ignore_bos_eos: + self.words = self.words[1:-1][::-1] + self.postags = self.postags[1:-1][::-1] + self.additional_feats = self.additional_feats[1:-1][::-1] + self.labels = self.labels[1:-1][::-1] + else: + self.words = self.words[::-1] + self.postags = self.postags[::-1] + self.additional_feats = self.additional_feats[::-1] + self.labels = self.labels[::-1] + + def add_row(self, word, postag, additional_feat, label): + self.words.append(word) + self.postags.append(postag) + self.additional_feats.append(additional_feat) + self.labels.append(label) + + def iterrows(self): + for i in range(len(self)): + yield self.get_word(i), self.get_postag(i), self.get_additional_feat(i), self.get_label(i) + + def __len__(self): + return len(self.words) + + def __repr__(self): + return self.to_string() + + def to_string(self, f_idx_dict=None, add_bos_eos=True): + if add_bos_eos: + self.words = [BOS] + self.words + [EOS] + self.postags = [BOS] + self.postags + [EOS] + if f_idx_dict: + self.additional_feats = [len(f_idx_dict.keys()) * [BOS]] + self.additional_feats + [len(f_idx_dict.keys()) * [EOS]] + else: + self.additional_feats = [] + + self.labels = [BOS] + self.labels + [EOS] + + tree_string = "" + for w, p, af, l in self.iterrows(): + # create the output line of the linearized tree + output_line = [w,p] + + # check for features + if f_idx_dict: + if w == BOS: + f_list = [BOS] * (len(f_idx_dict.keys())+1) + elif w == EOS: + f_list = [EOS] * (len(f_idx_dict.keys())+1) + else: + f_list = ["_"] * (len(f_idx_dict.keys())+1) + + if af != [None]: + for element in af: + key, value = element.split("=", 1) if len(element.split("=",1))==2 else (None, None) + if key in f_idx_dict.keys(): + f_list[f_idx_dict[key]] = value + + # append the additional elements or the placehodler + for element in f_list: + output_line.append(element) + + # add the label + output_line.append(str(l)) + tree_string+=u"\t".join(output_line)+u"\n" + + if add_bos_eos: + self.words = self.words[1:-1] + self.postags = self.postags[1:-1] + if f_idx_dict: + self.additional_feats = self.additional_feats[len(f_idx_dict.keys()):-len(f_idx_dict.keys())] + self.labels = self.labels[1:-1] + + return tree_string + + @staticmethod + def empty_tree(n_feats = 1): + temp_tree = LinearizedTree(labels=[], words=[], postags=[], additional_feats=[], n_feats=n_feats) + return temp_tree + + @staticmethod + def from_string(content, mode, separator="_", unary_joiner="|", n_features=0): + ''' + Reads a linearized tree from a string shaped as + -BOS- \t -BOS- \t (...) \t -BOS- \n + word \t postag \t (...) \t label \n + word \t postag \t (...) \t label \n + -EOS- \t -EOS- \t (...) \t -EOS- \n + ''' + labels = [] + words = [] + postags = [] + additional_feats = [] + + linearized_tree = None + for line in content.split("\n"): + if line=="\n": + print("Empty line") + # skip empty line + if len(line) <= 1: + continue + + # Separate the label file into columns + line_columns = line.split("\t") if ("\t") in line else line.split(" ") + word = line_columns[0] + + if BOS == word: + labels = [] + words = [] + postags = [] + additional_feats = [] + + continue + + if EOS == word: + linearized_tree = LinearizedTree(words, postags, additional_feats, labels, n_features) + continue + + if len(line_columns) == 2: + word, label = line_columns + postag = C_NO_POSTAG_LABEL + feats = "_" + elif len(line_columns) == 3: + word, postag, label = line_columns[0], line_columns[1], line_columns[2] + feats = "_" + else: + word, postag, *feats, label = line_columns[0], line_columns[1], line_columns[1:-1], line_columns[-1] + + # Check for predictions with no label + if BOS in label or EOS in label: + label = "1"+separator+"ROOT" + + words.append(word) + postags.append(postag) + if mode == "CONST": + labels.append(C_Label.from_string(label, separator, unary_joiner)) + elif mode == "DEPS": + labels.append(D_Label.from_string(label, separator)) + else: + raise ValueError("[!] Unknown mode: %s" % mode) + + additional_feats.append(feats) + + return linearized_tree \ No newline at end of file diff --git a/tania_scripts/supar/codelin/utils/__init__.py b/tania_scripts/supar/codelin/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e624705816d7181d4f932b5d47cf136958510862 Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d9b483a30076ad9fcf5cd79e963435945670210 Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653174b605a0b2dafa036c155f450c26128665e9 Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-310.pyc differ diff --git a/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee60e7632471dbdadeb0deae1b865ffba40f7ef2 Binary files /dev/null and b/tania_scripts/supar/codelin/utils/__pycache__/constants.cpython-311.pyc differ diff --git a/tania_scripts/supar/codelin/utils/constants.py b/tania_scripts/supar/codelin/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..66b0745e63dcd603f3f81dc61198c047bdd53ac8 --- /dev/null +++ b/tania_scripts/supar/codelin/utils/constants.py @@ -0,0 +1,54 @@ +# COMMON CONSTANTS + +EOS = "-EOS-" +BOS = "-BOS-" + +F_CONSTITUENT = "CONST" +F_DEPENDENCY = "DEPS" + +OP_ENC = "ENC" +OP_DEC = "DEC" + +# CONSTITUENT ENCODINGS + +C_ABSOLUTE_ENCODING = 'ABS' +C_RELATIVE_ENCODING = 'REL' +C_DYNAMIC_ENCODING = 'DYN' +C_INCREMENTAL_ENCODING = 'INC' + +C_STRAT_FIRST="strat_first" +C_STRAT_LAST="strat_last" +C_STRAT_MAX="strat_max" +C_STRAT_NONE="strat_none" + +# CONSTITUENT MISC + +C_NONE_LABEL = "-NONE-" +C_NO_POSTAG_LABEL = "-NOPOS-" +C_ROOT_LABEL = "-ROOT-" +C_END_LABEL = "-END-" +C_START_LABEL = "-START-" +C_CONFLICT_SEPARATOR = "-||-" +C_DUMMY_END = "DUMMY_END" + +# DEPENDENCIY ENCODINGS + +D_NONE_LABEL = "-NONE-" + +D_ABSOLUTE_ENCODING = 'ABS' +D_RELATIVE_ENCODING = 'REL' +D_POS_ENCODING = 'POS' +D_BRACKET_ENCODING = 'BRK' +D_BRACKET_ENCODING_2P = 'BRK_2P' + +D_2P_GREED = 'GREED' +D_2P_PROP = 'PROPAGATE' + +# DEPENDENCY MISC + +D_EMPTYREL = "-NOREL-" +D_POSROOT = "-ROOT-" +D_NULLHEAD = "-NULL-" + +D_ROOT_HEAD = "strat_gethead" +D_ROOT_REL = "strat_getrel" diff --git a/tania_scripts/supar/codelin/utils/extract_feats.py b/tania_scripts/supar/codelin/utils/extract_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..be867de99e84f24d9cc35366f9904502011c3be9 --- /dev/null +++ b/tania_scripts/supar/codelin/utils/extract_feats.py @@ -0,0 +1,41 @@ +import argparse +from supar.codelin.models.const_tree import C_Tree +from supar.codelin.models.deps_tree import D_Tree + +def extract_features_const(in_path): + file_in = open(in_path, "r") + feats_set = set() + for line in file_in: + line = line.rstrip() + tree = C_Tree.from_string(line) + tree.extract_features() + feats = tree.get_feature_names() + + feats_set = feats_set.union(feats) + + return sorted(feats_set) + +def extract_features_deps(in_path): + feats_list=set() + trees = D_Tree.read_conllu_file(in_path, filter_projective=False) + for t in trees: + for node in t: + if node.feats != "_": + feats_list = feats_list.union(a for a in (node.feats.keys())) + return sorted(feats_list) + + +''' +Python script that returns a ordered list of the features +included in a conll tree or a constituent tree +''' + +# parser = argparse.ArgumentParser(description='Prints all features in a constituent treebank') +# parser.add_argument('form', metavar='formalism', type=str, choices=['CONST','DEPS'], help='Grammar encoding the file to extract features') +# parser.add_argument('input', metavar='in file', type=str, help='Path of the file to clean (.trees file).') +# args = parser.parse_args() +# if args.form=='CONST': +# feats = extract_features_const(args.input) +# elif args.form=='DEPS': +# feats = extract_features_conll(args.input) +# print(" ".join(feats)) diff --git a/tania_scripts/supar/model.py b/tania_scripts/supar/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d52770165cede33df41efe7dd6aace5dd87452 --- /dev/null +++ b/tania_scripts/supar/model.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from supar.modules import (CharLSTM, ELMoEmbedding, IndependentDropout, + SharedDropout, TransformerEmbedding, + TransformerWordEmbedding, VariationalLSTM) +from supar.modules.transformer import (TransformerEncoder, + TransformerEncoderLayer) +from supar.utils import Config +#from supar.vector_quantize import VectorQuantize + +from vector_quantize_pytorch import VectorQuantize + + +class Model(nn.Module): + + def __init__(self, + n_words, + n_tags=None, + n_chars=None, + n_lemmas=None, + encoder='lstm', + feat=['tag', 'char'], + n_embed=300, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=45, + n_char_hidden=100, + char_pad_index=0, + char_dropout=0, + elmo_bos_eos=(True, True), + elmo_dropout=0.5, + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + encoder_dropout=.33, + pad_index=0, + + n_decoder_layers: int = 2, + bidirectional: bool = False, + + use_vq=False, + vq_passes: int = 300, + codebook_size=512, + vq_decay=0.3, + commitment_weight=3, + + **kwargs): + super().__init__() + + self.args = Config().update(locals()) + + if encoder == 'lstm': + self.word_embed = nn.Embedding(num_embeddings=self.args.n_words, + embedding_dim=self.args.n_embed) + + n_input = self.args.n_embed + + if self.args.n_pretrained != self.args.n_embed: + n_input += self.args.n_pretrained + if 'tag' in self.args.feat: + self.tag_embed = nn.Embedding(num_embeddings=self.args.n_tags, + embedding_dim=self.args.n_feat_embed) + n_input += self.args.n_feat_embed + + if 'char' in self.args.feat: + self.char_embed = CharLSTM(n_chars=self.args.n_chars, + n_embed=self.args.n_char_embed, + n_hidden=self.args.n_char_hidden, + n_out=self.args.n_feat_embed, + pad_index=self.args.char_pad_index, + dropout=self.args.char_dropout) + n_input += self.args.n_feat_embed + + + if 'lemma' in self.args.feat: + self.lemma_embed = nn.Embedding(num_embeddings=self.args.n_lemmas, + embedding_dim=self.args.n_feat_embed) + n_input += self.args.n_feat_embed + if 'elmo' in self.args.feat: + self.elmo_embed = ELMoEmbedding(n_out=self.args.n_plm_embed, + bos_eos=self.args.elmo_bos_eos, + dropout=self.args.elmo_dropout, + finetune=self.args.finetune) + n_input += self.elmo_embed.n_out + if 'bert' in self.args.feat: + self.bert_embed = TransformerEmbedding(name=self.args.bert, + n_layers=self.args.n_bert_layers, + n_out=self.args.n_plm_embed, + pooling=self.args.bert_pooling, + pad_index=self.args.bert_pad_index, + mix_dropout=self.args.mix_dropout, + finetune=self.args.finetune) + n_input += self.bert_embed.n_out + self.embed_dropout = IndependentDropout(p=self.args.embed_dropout) + self.encoder = VariationalLSTM( + input_size=n_input, + hidden_size=self.args.n_encoder_hidden//2 if self.args.bidirectional else self.args.n_encoder_hidden, + num_layers=self.args.n_encoder_layers, bidirectional=self.args.bidirectional, + dropout=self.args.encoder_dropout) + self.encoder_dropout = SharedDropout(p=self.args.encoder_dropout) + elif encoder == 'transformer': + self.word_embed = TransformerWordEmbedding(n_vocab=self.args.n_words, + n_embed=self.args.n_embed, + pos=self.args.pos, + pad_index=self.args.pad_index) + self.embed_dropout = nn.Dropout(p=self.args.embed_dropout) + self.encoder = TransformerEncoder(layer=TransformerEncoderLayer(n_heads=self.args.n_encoder_heads, + n_model=self.args.n_encoder_hidden, + n_inner=self.args.n_encoder_inner, + attn_dropout=self.args.encoder_attn_dropout, + ffn_dropout=self.args.encoder_ffn_dropout, + dropout=self.args.encoder_dropout), + n_layers=self.args.n_encoder_layers, + n_model=self.args.n_encoder_hidden) + self.encoder_dropout = nn.Dropout(p=self.args.encoder_dropout) + elif encoder == 'bert': + self.encoder = TransformerEmbedding(name=self.args.bert, + n_layers=self.args.n_bert_layers, + n_out=self.args.n_encoder_hidden, + pooling=self.args.bert_pooling, + pad_index=self.args.pad_index, + mix_dropout=self.args.mix_dropout, + finetune=self.args.finetune) + self.encoder_dropout = nn.Dropout(p=self.args.encoder_dropout) + self.args.n_encoder_hidden = self.encoder.n_out + + self.passes_remaining = vq_passes + if use_vq: + self.vq = VectorQuantize(dim=self.args.n_encoder_hidden, codebook_size=codebook_size, decay=vq_decay, + commitment_weight=commitment_weight, eps=1e-5) #, wait_steps=0, observe_steps=vq_passes) + else: + self.vq = nn.Identity() + + def load_pretrained(self, embed=None): + if embed is not None: + self.pretrained = nn.Embedding.from_pretrained(embed) + if embed.shape[1] != self.args.n_pretrained: + self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained) + nn.init.zeros_(self.word_embed.weight) + return self + + def forward(self): + raise NotImplementedError + + def vq_forward(self, x: torch.Tensor): + if not self.args.use_vq: + return x, torch.tensor(0) + + if self.passes_remaining > (self.args.vq_passes / 2): + _, _, commit_loss, _ = self.vq(x) + self.passes_remaining -= 1 + elif 0 < self.passes_remaining < (self.args.vq_passes / 2): + x_quantized, _, commit_loss, _ = self.vq(x) + x = torch.lerp(x, x_quantized, (self.passes_remaining - self.passes) / self.passes) + self.passes_remaining -= 1 + else: + x, _, commit_loss, _ = self.vq(x) + qloss = commit_loss.squeeze() + + return x, qloss + + def loss(self): + raise NotImplementedError + + def embed(self, words, feats=None):#feats=None + ext_words = words + # set the indices larger than num_embeddings to unk_index + if hasattr(self, 'pretrained'): + ext_mask = words.ge(self.word_embed.num_embeddings) + ext_words = words.masked_fill(ext_mask, self.args.unk_index) + + # get outputs from embedding layers + word_embed = self.word_embed(ext_words) + if hasattr(self, 'pretrained'): + pretrained = self.pretrained(words) + if self.args.n_embed == self.args.n_pretrained: + word_embed += pretrained + else: + word_embed = torch.cat((word_embed, self.embed_proj(pretrained)), -1) + feat_embed = [] + + if 'tag' in self.args.feat: + feat_embed.append(self.tag_embed(feats.pop(0))) + if 'char' in self.args.feat: + + feat_embed.append(self.char_embed(feats.pop(0))) + if 'elmo' in self.args.feat: + feat_embed.append(self.elmo_embed(feats.pop(0))) + if 'bert' in self.args.feat: + feat_embed.append(self.bert_embed(feats.pop(0))) + if 'lemma' in self.args.feat: + feat_embed.append(self.lemma_embed(feats.pop(0))) + if isinstance(self.embed_dropout, IndependentDropout): + if len(feat_embed) == 0: + raise RuntimeError(f"`feat` is not allowed to be empty, which is {self.args.feat} now") + embed = torch.cat(self.embed_dropout(word_embed, torch.cat(feat_embed, -1)), -1) + else: + embed = word_embed + if len(feat_embed) > 0: + embed = torch.cat((embed, torch.cat(feat_embed, -1)), -1) + embed = self.embed_dropout(embed) + return embed + + def encode(self, words, feats):#=None): + if self.args.encoder == 'lstm': + x = pack_padded_sequence(self.embed(words, feats), words.ne(self.args.pad_index).sum(1).tolist(), True, False) + + x, _ = self.encoder(x) + x, _ = pad_packed_sequence(x, True, total_length=words.shape[1]) + elif self.args.encoder == 'transformer': + x = self.encoder(self.embed(words, feats), words.ne(self.args.pad_index)) + else: + x = self.encoder(words) + return self.encoder_dropout(x) + + def decode(self): + raise NotImplementedError + + def vq_forward(self, x: torch.Tensor): + if not self.args.use_vq: + return x, torch.tensor(0) + + if self.passes_remaining > (self.args.vq_passes / 2): + _, _, commit_loss = self.vq(x) + self.passes_remaining -= 1 + elif 0 < self.passes_remaining < (self.args.vq_passes / 2): + x_quantized, _, commit_loss = self.vq(x) + x = torch.lerp(x, x_quantized, (self.passes_remaining - self.passes) / self.passes) + self.passes_remaining -= 1 + else: + x, indices, commit_loss = self.vq(x) + qloss = commit_loss.squeeze() + return x, qloss + + @property + def device(self): + if self.args.device == 'cpu': + return 'cpu' + else: + return f'cuda:{self.args.device}' diff --git a/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e54ae9fad6a968ecbdb484b334c299ce00f07d8e --- /dev/null +++ b/tania_scripts/supar/models/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +from .const import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, CRFConstituencyParser, + TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser) +from .dep import (BiaffineDependencyParser, CRF2oDependencyParser, + CRFDependencyParser, VIDependencyParser, + SLDependencyParser, ArcEagerDependencyParser) +from .sdp import BiaffineSemanticDependencyParser, VISemanticDependencyParser + +__all__ = ['BiaffineDependencyParser', + 'CRFDependencyParser', + 'CRF2oDependencyParser', + 'VIDependencyParser', + 'AttachJuxtaposeConstituencyParser', + 'CRFConstituencyParser', + 'TetraTaggingConstituencyParser', + 'VIConstituencyParser', + 'SLConstituentParser', + 'BiaffineSemanticDependencyParser', + 'VISemanticDependencyParser', + 'SLDependencyParser', + 'ArcEagerDependencyParser'] \ No newline at end of file diff --git a/tania_scripts/supar/models/__init__.py b/tania_scripts/supar/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e54ae9fad6a968ecbdb484b334c299ce00f07d8e --- /dev/null +++ b/tania_scripts/supar/models/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +from .const import (AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos, CRFConstituencyParser, + TetraTaggingConstituencyParser, VIConstituencyParser, SLConstituentParser) +from .dep import (BiaffineDependencyParser, CRF2oDependencyParser, + CRFDependencyParser, VIDependencyParser, + SLDependencyParser, ArcEagerDependencyParser) +from .sdp import BiaffineSemanticDependencyParser, VISemanticDependencyParser + +__all__ = ['BiaffineDependencyParser', + 'CRFDependencyParser', + 'CRF2oDependencyParser', + 'VIDependencyParser', + 'AttachJuxtaposeConstituencyParser', + 'CRFConstituencyParser', + 'TetraTaggingConstituencyParser', + 'VIConstituencyParser', + 'SLConstituentParser', + 'BiaffineSemanticDependencyParser', + 'VISemanticDependencyParser', + 'SLDependencyParser', + 'ArcEagerDependencyParser'] \ No newline at end of file diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c56ad26f856966acabe985c3781a35035bb4d79 Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9521805749c757cbf900e8dbd4ebb1339359dd88 Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ca2e6217fc9a781a4fd1f5df3b0ce2c9a6ddf79 Binary files /dev/null and b/tania_scripts/supar/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/tania_scripts/supar/models/const/__init__.py b/tania_scripts/supar/models/const/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a15c165338e852b5d02ed348db763d43da09628 --- /dev/null +++ b/tania_scripts/supar/models/const/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from .aj import (AttachJuxtaposeConstituencyModel,AttachJuxtaposeConstituencyModelPos, + AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos) +from .crf import CRFConstituencyModel, CRFConstituencyParser +from .tt import TetraTaggingConstituencyModel, TetraTaggingConstituencyParser +from .vi import VIConstituencyModel, VIConstituencyParser +from .sl import SLConstituentParser, SLConstituentModel + +__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos', 'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos', + 'CRFConstituencyModel', 'CRFConstituencyParser', + 'TetraTaggingConstituencyModel', 'TetraTaggingConstituencyParser', + 'VIConstituencyModel', 'VIConstituencyParser', + 'SLConstituentModel', 'SLConstituentParser'] diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d250b33146033de8f6e222867a4f01c6f2d8a54 Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24326643d82fa6033432ee9bbd7ebf061b74a4d0 Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19a59b8eb7ae83aa8e4089626004ec916d1b1ec8 Binary files /dev/null and b/tania_scripts/supar/models/const/__pycache__/__init__.cpython-39.pyc differ diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe4855b790afea3b677caaaa7de5f6ac3549216 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos +from .parser import AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos + +__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos', 'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos'] diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..30de07c6de4c6e6c60b32a31a94aa0a29040d8ff --- /dev/null +++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/model-checkpoint.py @@ -0,0 +1,786 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple + +import torch +import torch.nn as nn +from supar.model import Model +from supar.models.const.aj.transform import AttachJuxtaposeTree +from supar.modules import GraphConvolutionalNetwork, MLP, DecoderLSTM +from supar.utils import Config +from supar.utils.common import INF +from supar.utils.fn import pad + +class DecoderLSTMPos(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, device): + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, + batch_first=True, dropout=dropout) + self.classifier = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + # x: [batch_size, seq_len, input_dim] + output, _ = self.lstm(x) + logits = self.classifier(output) + return logits + +class AttachJuxtaposeConstituencyModel(Model): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layers. Default: .33. + n_gnn_layers (int): + The number of GNN layers. Default: 3. + gnn_dropout (float): + The dropout ratio of GNN layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_gnn_layers=3, + gnn_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # the last one represents the dummy node in the initial states + self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden) + self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden, + n_layers=self.args.n_gnn_layers, + dropout=self.args.gnn_dropout) + + self.node_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 1), + ) + self.label_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels), + ) + + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1), + n_out=self.args.n_encoder_hidden, dropout=gnn_dropout) + + self.criterion = nn.CrossEntropyLoss() + + def forward( + self, + words: torch.LongTensor, + feats: List[torch.LongTensor] + ) -> Tuple[torch.Tensor]: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor: + Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input. + """ + x = self.encode(words, feats) + + # adjust lengths to allow delay predictions + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i)] for i in range(self.args.delay + 1)], dim=2) + x = self.delay_proj(x) + + # pass through vector quantization + x, qloss = self.vq_forward(x) + + return x, qloss + + def loss( + self, + x: torch.Tensor, + nodes: torch.LongTensor, + parents: torch.LongTensor, + news: torch.LongTensor, + mask: torch.BoolTensor + ) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + nodes (~torch.LongTensor): ``[batch_size, seq_len]``. + The target node positions on rightmost chains. + parents (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of terminals. + news (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of juxtaposed targets and terminals. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + The training loss. + """ + + spans, s_node, x_node = None, [], [] + actions = torch.stack((nodes, parents, news)) + for t, action in enumerate(actions.unbind(-1)): + if t == 0: + x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1) + s_node.append(self.node_classifier(x_rightmost).squeeze(-1)) + # we found softmax is slightly better than sigmoid in the original paper + s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0) + x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1)) + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index) + s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1) + s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1) + s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1) + s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1) + node_loss = self.criterion(s_node[mask], nodes[mask]) + label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask]) + return node_loss + label_loss + + def decode( + self, + x: torch.Tensor, + mask: torch.BoolTensor, + beam_size: int = 1 + ) -> List[List[Tuple]]: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + beam_size (int): + Beam size for decoding. Default: 1. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + tokenwise_predictions = [] + spans = None + batch_size, *_ = x.shape + n_labels = self.args.n_labels + # [batch_size * beam_size, ...] + x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:]) + mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:]) + # [batch_size] + batches = x.new_tensor(range(batch_size)).long() * beam_size + # accumulated scores + scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1) + for t in range(x.shape[1]): + if t == 0: + x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1) + s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1) + # we found softmax is slightly better than sigmoid in the original paper + x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1) + s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1) + s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1) + if t == 0: + s_parent[:, self.args.nul_index] = -INF + s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF + s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1) + s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1) + s_new, news = s_new.topk(min(n_labels, beam_size), -1) + s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1) + s_action = s_action.view(x.shape[0], -1) + k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1] + # [batch_size * beam_size, k_beam] + scores = scores.unsqueeze(-1) + s_action + # [batch_size, beam_size] + scores, cands = scores.view(batch_size, -1).topk(beam_size, -1) + # [batch_size * beam_size] + scores = scores.view(-1) + beams = cands.div(k_beam, rounding_mode='floor') + nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor')) + indices = (batches.unsqueeze(-1) + beams).view(-1) + + #print('indices', indices) + parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent) + news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent) + action = torch.stack((nodes, parents, news)).view(3, -1) + tokenwise_predictions.append([t, [x[0] for x in action.tolist()]]) + spans = spans[indices] if spans is not None else None + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + #print("SPANS", spans) + mask = mask.view(batch_size, beam_size, -1)[:, 0] + # select an 1-best tree for each sentence + spans = spans[batches + scores.view(batch_size, -1).argmax(-1)] + span_mask = spans.ge(0) + span_indices = torch.where(span_mask) + span_labels = spans[span_indices] + chart_preds = [[] for _ in range(x.shape[0])] + for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()): + chart_preds[i].append(span) + kk = [chart_preds + tokenwise_predictions] + return kk + + def rightmost_chain( + self, + x: torch.Tensor, + spans: torch.LongTensor, + mask: torch.BoolTensor, + t: int + ) -> torch.Tensor: + x_p, mask_p = x[:, :t], mask[:, :t] + lens = mask_p.sum(-1) + span_mask = spans[:, :-1, 1:].ge(0) + span_lens = span_mask.sum((-1, -2)) + span_indices = torch.where(span_mask) + span_labels = spans[:, :-1, 1:][span_indices] + x_span = self.label_embed(span_labels) + x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]] + node_lens = lens + span_lens + adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max()))) + x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1]))) + span_mask = ~x_mask & adj_mask + # concatenate terminals and spans + x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p]) + x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span) + adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1]) + adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1) + adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:])) + adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0] + adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2)) + # set the parent of root as itself + adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1)) + adj_parent = adj_parent & span_mask.unsqueeze(1) + # closet ancestor spans as parents + adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1) + adj.scatter_(-1, adj_parent.unsqueeze(-1), 1) + adj = (adj | adj.transpose(-1, -2)).float() + x_tree = self.gnn_layers(x_tree, adj, adj_mask) + span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1)) + span_lens = span_mask.sum(-1) + x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree) + return x_span + + +class AttachJuxtaposeConstituencyModelPos(Model): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layers. Default: .33. + n_gnn_layers (int): + The number of GNN layers. Default: 3. + gnn_dropout (float): + The dropout ratio of GNN layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=16, + n_chars=None, + encoder='lstm', + feat=['char', 'tag'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_gnn_layers=3, + gnn_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # the last one represents the dummy node in the initial states + self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden) + self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden, + n_layers=self.args.n_gnn_layers, + dropout=self.args.gnn_dropout) + + self.node_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 1), + ) + self.label_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels), + ) + self.pos_classifier = DecoderLSTMPos( + self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags, + num_layers=1, dropout=encoder_dropout, device=self.device + ) + + #self.pos_tagger = nn.Identity() + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1), + n_out=self.args.n_encoder_hidden, dropout=gnn_dropout) + + self.criterion = nn.CrossEntropyLoss() + + def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Applies encoding forward pass. Maps a tensor of word indices (`words`) to their corresponding neural + representation. + Args: + words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay] + feats: List[torch.Tensor] + lens: List[int] + + Returns: x, qloss + x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + qloss: torch.FloatTensor ~ 1 + + """ + + x = super().encode(words, feats) + s_tag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :]) + + # adjust lengths to allow delay predictions + # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2) + x = self.delay_proj(x) + + # pass through vector quantization + x, qloss = self.vq_forward(x) + return x, s_tag, qloss + + + def forward( + self, + words: torch.LongTensor, + feats: List[torch.LongTensor] + ) -> Tuple[torch.Tensor]: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor: + Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input. + """ + x, s_tag, qloss = self.encoder_forward(words, feats) + + return x, s_tag, qloss + + def loss( + self, + x: torch.Tensor, + nodes: torch.LongTensor, + parents: torch.LongTensor, + news: torch.LongTensor, + mask: torch.BoolTensor, s_tags: torch.LongTensor, tags: torch.LongTensor + ) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + nodes (~torch.LongTensor): ``[batch_size, seq_len]``. + The target node positions on rightmost chains. + parents (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of terminals. + news (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of juxtaposed targets and terminals. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + The training loss. + """ + + spans, s_node, x_node = None, [], [] + actions = torch.stack((nodes, parents, news)) + for t, action in enumerate(actions.unbind(-1)): + if t == 0: + x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1) + s_node.append(self.node_classifier(x_rightmost).squeeze(-1)) + # we found softmax is slightly better than sigmoid in the original paper + s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0) + x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1)) + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index) + s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1) + s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1) + #s_postag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :]).chunk(2, -1) + s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1) + s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1) + node_loss = self.criterion(s_node[mask], nodes[mask]) + #print('node loss', node_loss) + label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask]) + #print('label loss', label_loss) + + #print(s_tag[mask].shape, tags[mask].shape) + tag_loss = self.criterion(s_tags[mask], tags[mask]) + #print('tag loss', tag_loss) + #tag_loss = self.pos_loss(s_tags, tags, mask) + print("node loss, label loss, tag loss", node_loss, label_loss, tag_loss, node_loss + label_loss + tag_loss) + return node_loss + label_loss + tag_loss + + def decode( + self, + x: torch.Tensor, + mask: torch.BoolTensor, + beam_size: int = 1 + ) -> List[List[Tuple]]: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + beam_size (int): + Beam size for decoding. Default: 1. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + tokenwise_predictions = [] + spans = None + batch_size, *_ = x.shape + n_labels = self.args.n_labels + # [batch_size * beam_size, ...] + x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:]) + mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:]) + # [batch_size] + batches = x.new_tensor(range(batch_size)).long() * beam_size + # accumulated scores + scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1) + for t in range(x.shape[1]): + if t == 0: + x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1) + s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1) + # we found softmax is slightly better than sigmoid in the original paper + x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1) + s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1) + s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1) + if t == 0: + s_parent[:, self.args.nul_index] = -INF + s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF + s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1) + s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1) + s_new, news = s_new.topk(min(n_labels, beam_size), -1) + s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1) + s_action = s_action.view(x.shape[0], -1) + k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1] + # [batch_size * beam_size, k_beam] + scores = scores.unsqueeze(-1) + s_action + # [batch_size, beam_size] + scores, cands = scores.view(batch_size, -1).topk(beam_size, -1) + # [batch_size * beam_size] + scores = scores.view(-1) + beams = cands.div(k_beam, rounding_mode='floor') + nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor')) + indices = (batches.unsqueeze(-1) + beams).view(-1) + + #print('indices', indices) + parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent) + news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent) + action = torch.stack((nodes, parents, news)).view(3, -1) + tokenwise_predictions.append([t, [x[0] for x in action.tolist()]]) + spans = spans[indices] if spans is not None else None + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + #print("SPANS", spans) + mask = mask.view(batch_size, beam_size, -1)[:, 0] + # select an 1-best tree for each sentence + spans = spans[batches + scores.view(batch_size, -1).argmax(-1)] + span_mask = spans.ge(0) + span_indices = torch.where(span_mask) + span_labels = spans[span_indices] + chart_preds = [[] for _ in range(x.shape[0])] + for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()): + chart_preds[i].append(span) + kk = [chart_preds + tokenwise_predictions] + return kk + + def rightmost_chain( + self, + x: torch.Tensor, + spans: torch.LongTensor, + mask: torch.BoolTensor, + t: int + ) -> torch.Tensor: + x_p, mask_p = x[:, :t], mask[:, :t] + lens = mask_p.sum(-1) + span_mask = spans[:, :-1, 1:].ge(0) + span_lens = span_mask.sum((-1, -2)) + span_indices = torch.where(span_mask) + span_labels = spans[:, :-1, 1:][span_indices] + x_span = self.label_embed(span_labels) + x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]] + node_lens = lens + span_lens + adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max()))) + x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1]))) + span_mask = ~x_mask & adj_mask + # concatenate terminals and spans + x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p]) + x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span) + adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1]) + adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1) + adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:])) + adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0] + adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2)) + # set the parent of root as itself + adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1)) + adj_parent = adj_parent & span_mask.unsqueeze(1) + # closet ancestor spans as parents + adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1) + adj.scatter_(-1, adj_parent.unsqueeze(-1), 1) + adj = (adj | adj.transpose(-1, -2)).float() + x_tree = self.gnn_layers(x_tree, adj, adj_mask) + span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1)) + span_lens = span_mask.sum(-1) + x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree) + return x_span + + + def pos_loss(self, pos_logits: torch.Tensor, pos_tags: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor: + """ + Args: + pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags]. + pos_tags (~torch.LongTensor): [batch_size, seq_len]. + mask (~torch.BoolTensor): [batch_size, seq_len]. + + Returns: + torch.Tensor: The POS tagging loss. + """ + loss_fn = nn.CrossEntropyLoss() + return loss_fn(pos_logits[mask], pos_tags[mask]) + + def decode_pos(self, s_tag: torch.Tensor): + """ + Decode the most likely POS tags. + + Args: + pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags] + mask (~torch.BoolTensor): [batch_size, seq_len] + + Returns: + List[List[int]]: POS tags per token for each sentence in the batch. + """ + pos_preds = pos_logits.argmax(-1) + #return [seq[mask[i]].tolist() for i, seq in enumerate(pos_preds)] + return pos_preds diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..543e8284b90805f045e1d5de2b15aa7c75ffc417 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/parser-checkpoint.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Dict, Iterable, Set, Union + +import torch + +from supar.models.const.aj.model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos +from supar.models.const.aj.transform import AttachJuxtaposeTree +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, EOS, NUL, PAD, UNK +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch +from torch.nn.utils.rnn import pad_sequence + + +logger = get_logger(__name__) + + +def compute_pos_accuracy(pos_gold, pos_preds): + correct = 0 + total = 0 + for gold_seq, pred_seq in zip(pos_gold, pos_preds): + for g, p in zip(gold_seq, pred_seq): + if g == p: + correct += 1 + total += len(gold_seq) + return correct, total + + + + +class AttachJuxtaposeConstituencyParser(Parser): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + """ + + NAME = 'attach-juxtapose-constituency' + MODEL = AttachJuxtaposeConstituencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.NODE = self.transform.NODE + self.PARENT = self.transform.PARENT + self.NEW = self.transform.NEW + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + print("here") + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + verbose: bool = True, + **kwargs + ): + + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + #print("TRAIN STEP") + words, *feats, trees, nodes, parents, news = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + #print("s_tag", s_tag) + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + #print("EVAL STEP") + words, *feats, trees, nodes, parents, news = batch + #print("WORDS", words.shape, words) + mask = batch.mask[:, (2+self.args.delay):] + x, qloss = self.model(words, feats) + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print("CHART PREDS") + #print("self new vocab", self.NEW.vocab.items()) + #print() + preds = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + + for tree, chart in zip(trees, chart_preds): + print(tree, chart) + + print() + for tree in trees: + print("ORIG TREE", tree) + print() + for tree in preds: + print("PRED TREE", tree) + + print("=========================") + + return SpanMetric(loss, + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask = batch.mask[:, (2+self.args.delay):] + x, _ = self.model(words, feats) + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + chart_preds = chart_preds[0] + batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + raise NotImplementedError("Returning action probs are currently not supported yet.") + + new_tokenwise_preds = [] + for k, y in chart_preds[1:]: + new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]]) + + chart_preds = [[[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0]]] + new_tokenwise_preds + #for x in chart_preds[1:]: + # new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]] + + return chart_preds #batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR = None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay) + TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + TREE = RawField('trees') + NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK) + transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR is not None: + CHAR.build(train) + TAG.build(train) + PARENT, NEW = PARENT.build(train), NEW.build(train) + PARENT.vocab = NEW.vocab.update(PARENT.vocab) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(NEW.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index, + 'nul_index': NEW.vocab[NUL] + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser + +def flatten(x): + result = [] + for el in x: + if isinstance(el, list): + result.extend(flatten(el)) + else: + result.append(el) + return result + + + + +class AttachJuxtaposeConstituencyParserPos(Parser): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + """ + + NAME = 'attach-juxtapose-constituency' + MODEL = AttachJuxtaposeConstituencyModelPos + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.NODE = self.transform.NODE + self.PARENT = self.transform.PARENT + self.NEW = self.transform.NEW + self.TAG = self.transform.POS + + #print(self.TAG) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + verbose: bool = True, + **kwargs + ): + + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + #print("TRAIN STEP") + words, *feats, postags, nodes, parents, news = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + tags = feats[-1] + #loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + pos_logits + #loss = loss.mean() + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss + + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + + #print("EVAL STEP") + words, *feats, trees, nodes, parents, news = batch + + tags = feats[-1] + #loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + pos_logits + #loss = loss.mean() + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print('CHART PREDS', chart_preds[0]) + + pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)] + trimmed_tags = [tag_seq[1:-1] for tag_seq in tags] + trimmed_tags_list = [x.cpu().tolist() for x in trimmed_tags] + pos_gold = [[self.TAG.vocab[p] for p in sent] for sent in trimmed_tags_list] + + pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds) + pos_acc = pos_correct/pos_total + print(pos_acc) + + """ + tags = [*feats[-1:]] + trimmed_tags = [tag_seq[1:-1] for tag_seq in tags[0]] + gtags = [*feats[-1:]][0].cpu().tolist() + gtags = [x[1:-1] for x in gtags] + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + + tag_tensor = pad_sequence(trimmed_tags, batch_first=True, padding_value=self.TAG.pad_index).to(s_tag.device) + + assert s_tag.shape[:2] == tag_tensor.shape[:2] + print("s_tag shape", s_tag.shape[:2], "tag_tensor shape", tag_tensor.shape[:2]) + + + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tag_tensor) + qloss + + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)] + + for tag_seq in gtags: + for pos in tag_seq: + assert 0 <= pos < len(self.TAG.vocab) + + pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)] + pos_gold = [[self.TAG.vocab[pos] for pos in tag_seq] for tag_seq in gtags] + + #print("357!", len(pos_gold), len(pos_preds)) + pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds) + print("358!", pos_correct, pos_total, pos_correct/pos_total) + + + #print('%%% POS preds', pos_preds) + #print('%%%% POS gold', pos_gold) + """ + + #batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + # for tree, chart in zip(trees, chart_preds)] + #print(batch.trees) + + preds = [AttachJuxtaposeTree.build(tree, [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds[0])] + #print('POS preds', s_tag.argmax(-1)) + #print("PREDS", len(preds)) + + + """ + span_preds = [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0][0]] + #print("new SPAN preds", span_preds) + verbalized_preds = [] + + + for x in chart_preds[1:]: + new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]]] + #print(new_item) + verbalized_preds.append(new_item) + """ + + print("424", loss, pos_acc) + return SpanMetric(loss, + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees], pos_acc) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print('VHART PREDS: ', chart_preds) + #print('LEN VHART PREDS: ', len(chart_preds)) + + chart_preds = chart_preds[0] + batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + raise NotImplementedError("Returning action probs are currently not supported yet.") + + new_tokenwise_preds = [] + #for pred in chart_preds: + # print("PRED", pred) + # new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]]) + + + span_preds = [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0]] + new_tokenwise_preds + #print("new chart preds", span_preds) + verbalized_preds = [] + for x in chart_preds[1:]: + new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]]] + #print(new_item) + verbalized_preds.append(new_item) + #print('POS preds', s_tag.argmax(-1)) + pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)] + return span_preds, verbalized_preds, pos_preds #batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR = None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay) + TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + TREE = RawField('trees') + NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK) + transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR is not None: + CHAR.build(train) + TAG.build(train) + #print("203 TAG VOCAB", [x for x in TAG.vocab.items()]) + PARENT, NEW = PARENT.build(train), NEW.build(train) + PARENT.vocab = NEW.vocab.update(PARENT.vocab) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(NEW.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index, + 'nul_index': NEW.vocab[NUL] + }) + logger.info(f"{transform}") + + #print('TAG VOCAB', TAG.vocab.items()) + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser \ No newline at end of file diff --git a/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/.ipynb_checkpoints/transform-checkpoint.py @@ -0,0 +1,459 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union + +import nltk +import torch + +from supar.models.const.crf.transform import Tree +from supar.utils.common import NUL +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class AttachJuxtaposeTree(Tree): + r""" + :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class, + supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + NODE: + The target node on each rightmost chain. + PARENT: + The label of the parent node of each terminal. + NEW: + The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children. + ``NUL`` represents the `Attach` action. + """ + + fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + NODE: Optional[Union[Field, Iterable[Field]]] = None, + PARENT: Optional[Union[Field, Iterable[Field]]] = None, + NEW: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.NODE = NODE + self.PARENT = PARENT + self.NEW = NEW + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.NODE, self.PARENT, self.NEW + + @classmethod + def tree2action(cls, tree: nltk.Tree): + r""" + Converts a constituency tree into AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + A constituency tree in :class:`nltk.tree.Tree` format. + + Returns: + A sequence of AttachJuxtapose actions. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ Arthur)) + (VP + (_ is) + (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + >>> AttachJuxtaposeTree.tree2action(tree) + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')] + """ + + def isroot(node): + return node == tree[0] + + def isterminal(node): + return len(node) == 1 and not isinstance(node[0], nltk.Tree) + + def last_leaf(node): + pos = () + while True: + pos += (len(node) - 1,) + node = node[-1] + if isterminal(node): + return node, pos + + def parent(position): + return tree[position[:-1]] + + def grand(position): + return tree[position[:-2]] + + def detach(tree): + last, last_pos = last_leaf(tree) + siblings = parent(last_pos)[:-1] + + if len(siblings) > 0: + last_subtree = last + last_subtree_siblings = siblings + parent_label = NUL + else: + last_subtree, last_pos = parent(last_pos), last_pos[:-1] + last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1] + parent_label = last_subtree.label() + + target_pos, new_label, last_tree = 0, NUL, tree + if isroot(last_subtree): + last_tree = None + + elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]): + new_label = parent(last_pos).label() + new_label = new_label + target = last_subtree_siblings[0] + last_grand = grand(last_pos) + if last_grand is None: + last_tree = targetistermina + else: + last_grand[-1] = target + target_pos = len(last_pos) - 2 + else: + target = parent(last_pos) + target.pop() + target_pos = len(last_pos) - 2 + action = target_pos, parent_label, new_label + return action, last_tree + if tree is None: + return [] + action, last_tree = detach(tree) + return cls.tree2action(last_tree) + [action] + + @classmethod + def action2tree( + cls, + tree: nltk.Tree, + actions: List[Tuple[int, str, str]], + join: str = '::', + ) -> nltk.Tree: + r""" + Recovers a constituency tree from a sequence of AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + actions (List[Tuple[int, str, str]]): + A sequence of AttachJuxtapose actions. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.action2tree(tree, + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')]).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + """ + + def target(node, depth): + node_pos = () + for _ in range(depth): + node_pos += (len(node) - 1,) + node = node[-1] + return node, node_pos + + def parent(tree, position): + return tree[position[:-1]] + + def execute(tree: nltk.Tree, terminal: Tuple(str, str), action: Tuple[int, str, str]) -> nltk.Tree: + target_pos, parent_label, new_label, post = action + #print(target_pos, parent_label, new_label) + new_leaf = nltk.Tree(post, [terminal[0]]) + + # create the subtree to be inserted + new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf]) + # find the target position at which to insert the new subtree + target_node = tree + if target_node is not None: + target_node, target_pos = target(target_node, target_pos) + + # Attach + if new_label == NUL: + # attach the first token + if target_node is None: + return new_subtree + target_node.append(new_subtree) + # Juxtapose + else: + new_subtree = nltk.Tree(new_label, [target_node, new_subtree]) + if len(target_pos) > 0: + parent_node = parent(tree, target_pos) + parent_node[-1] = new_subtree + else: + tree = new_subtree + return tree + + tree, root, terminals = None, tree.label(), tree.pos() + for terminal, action in zip(terminals, actions): + tree = execute(tree, terminal, action) + # recover unary chains + nodes = [tree] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + nodes.extend(node) + if join in node.label(): + labels = node.label().split(join) + node.set_label(labels[0]) + subtree = nltk.Tree(labels[-1], node) + for label in reversed(labels[1:-1]): + subtree = nltk.Tree(label, [subtree]) + node[:] = [subtree] + return nltk.Tree(root, [tree]) + + @classmethod + def action2span( + cls, + action: torch.Tensor, + spans: torch.Tensor = None, + nul_index: int = -1, + mask: torch.BoolTensor = None + ) -> torch.Tensor: + r""" + Converts a batch of the tensorized action at a given step into spans. + + Args: + action (~torch.Tensor): ``[3, batch_size]``. + A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels. + spans (~torch.Tensor): + Spans generated at previous steps, ``None`` at the first step. Default: ``None``. + nul_index (int): + The index for the obj:`NUL` token, representing the Attach action. Default: -1. + mask (~torch.BoolTensor): ``[batch_size]``. + The mask for covering the unpadded tokens. + + Returns: + A tensor representing a batch of spans for the given step. + + Examples: + >>> from collections import Counter + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab + >>> from supar.utils.common import NUL + >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL), + (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL), + (0, NUL, NUL)]) + >>> vocab = Vocab(Counter(sorted(set([*parents, *news])))) + >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1) + >>> spans = None + >>> for action in actions.unbind(-1): + ... spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL]) + ... + >>> spans + tensor([[[-1, 1, -1, -1, -1, -1, -1, 3], + [-1, -1, -1, -1, -1, -1, 4, -1], + [-1, -1, -1, 1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, 2, -1], + [-1, -1, -1, -1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1]]]) + >>> sequence = torch.where(spans.ge(0)) + >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]])) + >>> sequence + [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')] + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + + """ + + # [batch_size] + target, parent, new = action + if spans is None: + spans = action.new_full((action.shape[1], 2, 2), -1) + spans[:, 0, 1] = parent + return spans + if mask is None: + mask = torch.ones_like(target, dtype=bool) + juxtapose_mask = new.ne(nul_index) & mask + # ancestor nodes are those on the rightmost chain and higher than the target node + # [batch_size, seq_len] + rightmost_mask = spans[..., -1].ge(0) + ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1 + # should not include the target node for the Juxtapose action + ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1)) + target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1] + # the right boundaries of ancestor nodes should be aligned with the new generated terminals + spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1) + spans[..., -2].masked_fill_(ancestor_mask, -1) + spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask] + spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask] + # [batch_size, seq_len+1, seq_len+1] + spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1) + return spans + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[AttachJuxtaposeTreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`AttachJuxtaposeTreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + except IndexError: + tree = nltk.Tree.fromstring('(S ' + s + ')') + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + else: + yield sentence + index += 1 + self.root = tree.label() + + +class AttachJuxtaposeTreeSentence(Sentence): + r""" + Args: + transform (AttachJuxtaposeTree): + A :class:`AttachJuxtaposeTree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: AttachJuxtaposeTree, + tree: nltk.Tree, + index: Optional[int] = None + ) -> AttachJuxtaposeTreeSentence: + super().__init__(transform, index) + + words, tags = zip(*tree.pos()) + nodes, parents, news = None, None, None + if transform.training: + oracle_tree = tree.copy(True) + # the root node must have a unary chain + if len(oracle_tree) > 1: + oracle_tree[:] = [nltk.Tree('*', oracle_tree)] + oracle_tree.collapse_unary(joinChar='::') + if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree): + oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]]) + nodes, parents, news = zip(*transform.tree2action(oracle_tree)) + tags = [x.split("##")[0] for x in tags] + self.values = [words, tags, tree, nodes, parents, news] + + def __repr__(self): + return self.values[-4].pformat(1000000) + + def pretty_print(self): + self.values[-4].pretty_print() diff --git a/tania_scripts/supar/models/const/aj/__init__.py b/tania_scripts/supar/models/const/aj/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe4855b790afea3b677caaaa7de5f6ac3549216 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos +from .parser import AttachJuxtaposeConstituencyParser, AttachJuxtaposeConstituencyParserPos + +__all__ = ['AttachJuxtaposeConstituencyModel', 'AttachJuxtaposeConstituencyModelPos', 'AttachJuxtaposeConstituencyParser', 'AttachJuxtaposeConstituencyParserPos'] diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf09a38138e561426e88c74ad582913e27d79d0e Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bd79cb4caac4fa72e4cd50659aee9fc13dcc3d4 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0340d7a5173f1782e452ef95c2faf9025e674d Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/__init__.cpython-39.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da15ae58bc418dda6f19b8e44874d984ba30fda Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a76052d5e93b7b2ce0a33eae863662d4a4a1438 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d40ac4f8b95a668781a063d3ee4aa591fe26927 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/model.cpython-39.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..080b0dba88ea7c05edcde2a12f2d5e17b2756b70 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b58212f55501521f49a41f39da967e6efed3566 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731814f0b914ad90ae5e604f6a0883b85690c09e Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aba0194d23a0b3d6f8024b35018c2b00be0c9752 Binary files /dev/null and b/tania_scripts/supar/models/const/aj/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/aj/model.py b/tania_scripts/supar/models/const/aj/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4a8ce0b5ab9a6a35e8adb18eecb6b8cacecd81 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/model.py @@ -0,0 +1,786 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple + +import torch +import torch.nn as nn +from supar.model import Model +from supar.models.const.aj.transform import AttachJuxtaposeTree +from supar.modules import GraphConvolutionalNetwork, MLP, DecoderLSTM +from supar.utils import Config +from supar.utils.common import INF +from supar.utils.fn import pad + +class DecoderLSTMPos(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, device): + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, + batch_first=True, dropout=dropout) + self.classifier = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + # x: [batch_size, seq_len, input_dim] + output, _ = self.lstm(x) + logits = self.classifier(output) + return logits + +class AttachJuxtaposeConstituencyModel(Model): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layers. Default: .33. + n_gnn_layers (int): + The number of GNN layers. Default: 3. + gnn_dropout (float): + The dropout ratio of GNN layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_gnn_layers=3, + gnn_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # the last one represents the dummy node in the initial states + self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden) + self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden, + n_layers=self.args.n_gnn_layers, + dropout=self.args.gnn_dropout) + + self.node_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 1), + ) + self.label_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels), + ) + + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1), + n_out=self.args.n_encoder_hidden, dropout=gnn_dropout) + + self.criterion = nn.CrossEntropyLoss() + + def forward( + self, + words: torch.LongTensor, + feats: List[torch.LongTensor] + ) -> Tuple[torch.Tensor]: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor: + Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input. + """ + x = self.encode(words, feats) + + # adjust lengths to allow delay predictions + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i)] for i in range(self.args.delay + 1)], dim=2) + x = self.delay_proj(x) + + # pass through vector quantization + x, qloss = self.vq_forward(x) + + return x, qloss + + def loss( + self, + x: torch.Tensor, + nodes: torch.LongTensor, + parents: torch.LongTensor, + news: torch.LongTensor, + mask: torch.BoolTensor + ) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + nodes (~torch.LongTensor): ``[batch_size, seq_len]``. + The target node positions on rightmost chains. + parents (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of terminals. + news (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of juxtaposed targets and terminals. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + The training loss. + """ + + spans, s_node, x_node = None, [], [] + actions = torch.stack((nodes, parents, news)) + for t, action in enumerate(actions.unbind(-1)): + if t == 0: + x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1) + s_node.append(self.node_classifier(x_rightmost).squeeze(-1)) + # we found softmax is slightly better than sigmoid in the original paper + s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0) + x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1)) + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index) + s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1) + s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1) + s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1) + s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1) + node_loss = self.criterion(s_node[mask], nodes[mask]) + label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask]) + return node_loss + label_loss + + def decode( + self, + x: torch.Tensor, + mask: torch.BoolTensor, + beam_size: int = 1 + ) -> List[List[Tuple]]: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + beam_size (int): + Beam size for decoding. Default: 1. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + tokenwise_predictions = [] + spans = None + batch_size, *_ = x.shape + n_labels = self.args.n_labels + # [batch_size * beam_size, ...] + x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:]) + mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:]) + # [batch_size] + batches = x.new_tensor(range(batch_size)).long() * beam_size + # accumulated scores + scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1) + for t in range(x.shape[1]): + if t == 0: + x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1) + s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1) + # we found softmax is slightly better than sigmoid in the original paper + x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1) + s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1) + s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1) + if t == 0: + s_parent[:, self.args.nul_index] = -INF + s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF + s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1) + s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1) + s_new, news = s_new.topk(min(n_labels, beam_size), -1) + s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1) + s_action = s_action.view(x.shape[0], -1) + k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1] + # [batch_size * beam_size, k_beam] + scores = scores.unsqueeze(-1) + s_action + # [batch_size, beam_size] + scores, cands = scores.view(batch_size, -1).topk(beam_size, -1) + # [batch_size * beam_size] + scores = scores.view(-1) + beams = cands.div(k_beam, rounding_mode='floor') + nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor')) + indices = (batches.unsqueeze(-1) + beams).view(-1) + + #print('indices', indices) + parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent) + news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent) + action = torch.stack((nodes, parents, news)).view(3, -1) + tokenwise_predictions.append([t, [x[0] for x in action.tolist()]]) + spans = spans[indices] if spans is not None else None + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + #print("SPANS", spans) + mask = mask.view(batch_size, beam_size, -1)[:, 0] + # select an 1-best tree for each sentence + spans = spans[batches + scores.view(batch_size, -1).argmax(-1)] + span_mask = spans.ge(0) + span_indices = torch.where(span_mask) + span_labels = spans[span_indices] + chart_preds = [[] for _ in range(x.shape[0])] + for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()): + chart_preds[i].append(span) + kk = [chart_preds + tokenwise_predictions] + return kk + + def rightmost_chain( + self, + x: torch.Tensor, + spans: torch.LongTensor, + mask: torch.BoolTensor, + t: int + ) -> torch.Tensor: + x_p, mask_p = x[:, :t], mask[:, :t] + lens = mask_p.sum(-1) + span_mask = spans[:, :-1, 1:].ge(0) + span_lens = span_mask.sum((-1, -2)) + span_indices = torch.where(span_mask) + span_labels = spans[:, :-1, 1:][span_indices] + x_span = self.label_embed(span_labels) + x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]] + node_lens = lens + span_lens + adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max()))) + x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1]))) + span_mask = ~x_mask & adj_mask + # concatenate terminals and spans + x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p]) + x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span) + adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1]) + adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1) + adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:])) + adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0] + adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2)) + # set the parent of root as itself + adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1)) + adj_parent = adj_parent & span_mask.unsqueeze(1) + # closet ancestor spans as parents + adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1) + adj.scatter_(-1, adj_parent.unsqueeze(-1), 1) + adj = (adj | adj.transpose(-1, -2)).float() + x_tree = self.gnn_layers(x_tree, adj, adj_mask) + span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1)) + span_lens = span_mask.sum(-1) + x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree) + return x_span + + +class AttachJuxtaposeConstituencyModelPos(Model): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layers. Default: .33. + n_gnn_layers (int): + The number of GNN layers. Default: 3. + gnn_dropout (float): + The dropout ratio of GNN layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=16, + n_chars=None, + encoder='lstm', + feat=['char', 'tag'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_gnn_layers=3, + gnn_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # the last one represents the dummy node in the initial states + self.label_embed = nn.Embedding(n_labels+1, self.args.n_encoder_hidden) + self.gnn_layers = GraphConvolutionalNetwork(n_model=self.args.n_encoder_hidden, + n_layers=self.args.n_gnn_layers, + dropout=self.args.gnn_dropout) + + self.node_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 1), + ) + self.label_classifier = nn.Sequential( + nn.Linear(2 * self.args.n_encoder_hidden, self.args.n_encoder_hidden // 2), + nn.LayerNorm(self.args.n_encoder_hidden // 2), + nn.ReLU(), + nn.Linear(self.args.n_encoder_hidden // 2, 2 * n_labels), + ) + self.pos_classifier = DecoderLSTMPos( + self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags, + num_layers=1, dropout=encoder_dropout, device=self.device + ) + + #self.pos_tagger = nn.Identity() + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay+1), + n_out=self.args.n_encoder_hidden, dropout=gnn_dropout) + + self.criterion = nn.CrossEntropyLoss() + + def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Applies encoding forward pass. Maps a tensor of word indices (`words`) to their corresponding neural + representation. + Args: + words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay] + feats: List[torch.Tensor] + lens: List[int] + + Returns: x, qloss + x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + qloss: torch.FloatTensor ~ 1 + + """ + + x = super().encode(words, feats) + s_tag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :]) + + # adjust lengths to allow delay predictions + # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2) + x = self.delay_proj(x) + + # pass through vector quantization + x, qloss = self.vq_forward(x) + return x, s_tag, qloss + + + def forward( + self, + words: torch.LongTensor, + feats: List[torch.LongTensor] + ) -> Tuple[torch.Tensor]: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor: + Contextualized output hidden states of shape ``[batch_size, seq_len, n_model]`` of the input. + """ + x, s_tag, qloss = self.encoder_forward(words, feats) + + return x, s_tag, qloss + + def loss( + self, + x: torch.Tensor, + nodes: torch.LongTensor, + parents: torch.LongTensor, + news: torch.LongTensor, + mask: torch.BoolTensor, s_tags: torch.LongTensor, tags: torch.LongTensor + ) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + nodes (~torch.LongTensor): ``[batch_size, seq_len]``. + The target node positions on rightmost chains. + parents (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of terminals. + news (~torch.LongTensor): ``[batch_size, seq_len]``. + The parent node labels of juxtaposed targets and terminals. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + The training loss. + """ + + spans, s_node, x_node = None, [], [] + actions = torch.stack((nodes, parents, news)) + for t, action in enumerate(actions.unbind(-1)): + if t == 0: + x_span = self.label_embed(actions.new_full((x.shape[0], 1), self.args.n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_rightmost = torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1) + s_node.append(self.node_classifier(x_rightmost).squeeze(-1)) + # we found softmax is slightly better than sigmoid in the original paper + s_node[-1] = s_node[-1].masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0) + x_node.append(torch.bmm(s_node[-1].softmax(-1).unsqueeze(1), x_span).squeeze(1)) + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + attach_mask = x.new_tensor(range(self.args.n_labels)).eq(self.args.nul_index) + s_node, x_node = pad(s_node, -INF).transpose(0, 1), torch.stack(x_node, 1) + s_parent, s_new = self.label_classifier(torch.cat((x, x_node), -1)).chunk(2, -1) + #s_postag = self.pos_classifier(x[:, 1:-(1+self.args.delay), :]).chunk(2, -1) + s_parent = torch.cat((s_parent[:, :1].masked_fill(attach_mask, -INF), s_parent[:, 1:]), 1) + s_new = torch.cat((s_new[:, :1].masked_fill(~attach_mask, -INF), s_new[:, 1:]), 1) + node_loss = self.criterion(s_node[mask], nodes[mask]) + #print('node loss', node_loss) + label_loss = self.criterion(s_parent[mask], parents[mask]) + self.criterion(s_new[mask], news[mask]) + #print('label loss', label_loss) + + #print(s_tag[mask].shape, tags[mask].shape) + tag_loss = self.criterion(s_tags[mask], tags[mask]) + #print('tag loss', tag_loss) + #tag_loss = self.pos_loss(s_tags, tags, mask) + #print("node loss, label loss, tag loss", node_loss, label_loss, tag_loss, node_loss + label_loss + tag_loss) + return node_loss + label_loss + tag_loss + + def decode( + self, + x: torch.Tensor, + mask: torch.BoolTensor, + beam_size: int = 1 + ) -> List[List[Tuple]]: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, n_model]``. + Contextualized output hidden states. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + beam_size (int): + Beam size for decoding. Default: 1. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + tokenwise_predictions = [] + spans = None + batch_size, *_ = x.shape + n_labels = self.args.n_labels + # [batch_size * beam_size, ...] + x = x.unsqueeze(1).repeat(1, beam_size, 1, 1).view(-1, *x.shape[1:]) + mask = mask.unsqueeze(1).repeat(1, beam_size, 1).view(-1, *mask.shape[1:]) + # [batch_size] + batches = x.new_tensor(range(batch_size)).long() * beam_size + # accumulated scores + scores = x.new_full((batch_size, beam_size), -INF).index_fill_(-1, x.new_tensor(0).long(), 0).view(-1) + for t in range(x.shape[1]): + if t == 0: + x_span = self.label_embed(batches.new_full((x.shape[0], 1), n_labels)) + span_mask = mask[:, :1] + else: + x_span = self.rightmost_chain(x, spans, mask, t) + span_lens = spans[:, :-1, -1].ge(0).sum(-1) + span_mask = span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + s_node = self.node_classifier(torch.cat((x_span, x[:, t].unsqueeze(1).expand_as(x_span)), -1)).squeeze(-1) + s_node = s_node.masked_fill_(~span_mask, -INF).masked_fill(~span_mask.any(-1).unsqueeze(-1), 0).log_softmax(-1) + # we found softmax is slightly better than sigmoid in the original paper + x_node = torch.bmm(s_node.exp().unsqueeze(1), x_span).squeeze(1) + s_parent, s_new = self.label_classifier(torch.cat((x[:, t], x_node), -1)).chunk(2, -1) + s_parent, s_new = s_parent.log_softmax(-1), s_new.log_softmax(-1) + if t == 0: + s_parent[:, self.args.nul_index] = -INF + s_new[:, s_new.new_tensor(range(self.args.n_labels)).ne(self.args.nul_index)] = -INF + s_node, nodes = s_node.topk(min(s_node.shape[-1], beam_size), -1) + s_parent, parents = s_parent.topk(min(n_labels, beam_size), -1) + s_new, news = s_new.topk(min(n_labels, beam_size), -1) + s_action = s_node.unsqueeze(2) + (s_parent.unsqueeze(2) + s_new.unsqueeze(1)).view(x.shape[0], 1, -1) + s_action = s_action.view(x.shape[0], -1) + k_beam, k_node, k_parent = s_action.shape[-1], parents.shape[-1] * news.shape[-1], news.shape[-1] + # [batch_size * beam_size, k_beam] + scores = scores.unsqueeze(-1) + s_action + # [batch_size, beam_size] + scores, cands = scores.view(batch_size, -1).topk(beam_size, -1) + # [batch_size * beam_size] + scores = scores.view(-1) + beams = cands.div(k_beam, rounding_mode='floor') + nodes = nodes.view(batch_size, -1).gather(-1, cands.div(k_node, rounding_mode='floor')) + indices = (batches.unsqueeze(-1) + beams).view(-1) + + #print('indices', indices) + parents = parents[indices].view(batch_size, -1).gather(-1, cands.div(k_parent, rounding_mode='floor') % k_parent) + news = news[indices].view(batch_size, -1).gather(-1, cands % k_parent) + action = torch.stack((nodes, parents, news)).view(3, -1) + tokenwise_predictions.append([t, [x[0] for x in action.tolist()]]) + spans = spans[indices] if spans is not None else None + spans = AttachJuxtaposeTree.action2span(action, spans, self.args.nul_index, mask[:, t]) + #print("SPANS", spans) + mask = mask.view(batch_size, beam_size, -1)[:, 0] + # select an 1-best tree for each sentence + spans = spans[batches + scores.view(batch_size, -1).argmax(-1)] + span_mask = spans.ge(0) + span_indices = torch.where(span_mask) + span_labels = spans[span_indices] + chart_preds = [[] for _ in range(x.shape[0])] + for i, *span in zip(*[s.tolist() for s in span_indices], span_labels.tolist()): + chart_preds[i].append(span) + kk = [chart_preds + tokenwise_predictions] + return kk + + def rightmost_chain( + self, + x: torch.Tensor, + spans: torch.LongTensor, + mask: torch.BoolTensor, + t: int + ) -> torch.Tensor: + x_p, mask_p = x[:, :t], mask[:, :t] + lens = mask_p.sum(-1) + span_mask = spans[:, :-1, 1:].ge(0) + span_lens = span_mask.sum((-1, -2)) + span_indices = torch.where(span_mask) + span_labels = spans[:, :-1, 1:][span_indices] + x_span = self.label_embed(span_labels) + x_span += x[span_indices[0], span_indices[1]] + x[span_indices[0], span_indices[2]] + node_lens = lens + span_lens + adj_mask = node_lens.unsqueeze(-1).gt(x.new_tensor(range(node_lens.max()))) + x_mask = lens.unsqueeze(-1).gt(x.new_tensor(range(adj_mask.shape[-1]))) + span_mask = ~x_mask & adj_mask + # concatenate terminals and spans + x_tree = x.new_zeros(*adj_mask.shape, x.shape[-1]).masked_scatter_(x_mask.unsqueeze(-1), x_p[mask_p]) + x_tree = x_tree.masked_scatter_(span_mask.unsqueeze(-1), x_span) + adj = mask.new_zeros(*x_tree.shape[:-1], x_tree.shape[1]) + adj_spans = lens.new_tensor(range(x_tree.shape[1])).view(1, 1, -1).repeat(2, x.shape[0], 1) + adj_spans = adj_spans.masked_scatter_(span_mask.unsqueeze(0), torch.stack(span_indices[1:])) + adj_l, adj_r, adj_w = *adj_spans.unbind(), adj_spans[1] - adj_spans[0] + adj_parent = adj_l.unsqueeze(-1).ge(adj_l.unsqueeze(-2)) & adj_r.unsqueeze(-1).le(adj_r.unsqueeze(-2)) + # set the parent of root as itself + adj_parent.diagonal(0, 1, 2).copy_(adj_w.eq(t - 1)) + adj_parent = adj_parent & span_mask.unsqueeze(1) + # closet ancestor spans as parents + adj_parent = (adj_w.unsqueeze(-2) - adj_w.unsqueeze(-1)).masked_fill_(~adj_parent, t).argmin(-1) + adj.scatter_(-1, adj_parent.unsqueeze(-1), 1) + adj = (adj | adj.transpose(-1, -2)).float() + x_tree = self.gnn_layers(x_tree, adj, adj_mask) + span_mask = span_mask.masked_scatter(span_mask, span_indices[2].eq(t - 1)) + span_lens = span_mask.sum(-1) + x_tree, span_mask = x_tree[span_mask], span_lens.unsqueeze(-1).gt(x.new_tensor(range(span_lens.max()))) + x_span = x.new_zeros(*span_mask.shape, x.shape[-1]).masked_scatter_(span_mask.unsqueeze(-1), x_tree) + return x_span + + + def pos_loss(self, pos_logits: torch.Tensor, pos_tags: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor: + """ + Args: + pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags]. + pos_tags (~torch.LongTensor): [batch_size, seq_len]. + mask (~torch.BoolTensor): [batch_size, seq_len]. + + Returns: + torch.Tensor: The POS tagging loss. + """ + loss_fn = nn.CrossEntropyLoss() + return loss_fn(pos_logits[mask], pos_tags[mask]) + + def decode_pos(self, s_tag: torch.Tensor): + """ + Decode the most likely POS tags. + + Args: + pos_logits (~torch.Tensor): [batch_size, seq_len, n_tags] + mask (~torch.BoolTensor): [batch_size, seq_len] + + Returns: + List[List[int]]: POS tags per token for each sentence in the batch. + """ + pos_preds = pos_logits.argmax(-1) + #return [seq[mask[i]].tolist() for i, seq in enumerate(pos_preds)] + return pos_preds diff --git a/tania_scripts/supar/models/const/aj/parser.py b/tania_scripts/supar/models/const/aj/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..543e8284b90805f045e1d5de2b15aa7c75ffc417 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/parser.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Dict, Iterable, Set, Union + +import torch + +from supar.models.const.aj.model import AttachJuxtaposeConstituencyModel, AttachJuxtaposeConstituencyModelPos +from supar.models.const.aj.transform import AttachJuxtaposeTree +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, EOS, NUL, PAD, UNK +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch +from torch.nn.utils.rnn import pad_sequence + + +logger = get_logger(__name__) + + +def compute_pos_accuracy(pos_gold, pos_preds): + correct = 0 + total = 0 + for gold_seq, pred_seq in zip(pos_gold, pos_preds): + for g, p in zip(gold_seq, pred_seq): + if g == p: + correct += 1 + total += len(gold_seq) + return correct, total + + + + +class AttachJuxtaposeConstituencyParser(Parser): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + """ + + NAME = 'attach-juxtapose-constituency' + MODEL = AttachJuxtaposeConstituencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.NODE = self.transform.NODE + self.PARENT = self.transform.PARENT + self.NEW = self.transform.NEW + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + print("here") + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + verbose: bool = True, + **kwargs + ): + + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + #print("TRAIN STEP") + words, *feats, trees, nodes, parents, news = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + #print("s_tag", s_tag) + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + #print("EVAL STEP") + words, *feats, trees, nodes, parents, news = batch + #print("WORDS", words.shape, words) + mask = batch.mask[:, (2+self.args.delay):] + x, qloss = self.model(words, feats) + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print("CHART PREDS") + #print("self new vocab", self.NEW.vocab.items()) + #print() + preds = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + + for tree, chart in zip(trees, chart_preds): + print(tree, chart) + + print() + for tree in trees: + print("ORIG TREE", tree) + print() + for tree in preds: + print("PRED TREE", tree) + + print("=========================") + + return SpanMetric(loss, + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask = batch.mask[:, (2+self.args.delay):] + x, _ = self.model(words, feats) + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + chart_preds = chart_preds[0] + batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + raise NotImplementedError("Returning action probs are currently not supported yet.") + + new_tokenwise_preds = [] + for k, y in chart_preds[1:]: + new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]]) + + chart_preds = [[[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0]]] + new_tokenwise_preds + #for x in chart_preds[1:]: + # new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]] + + return chart_preds #batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR = None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay) + TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + TREE = RawField('trees') + NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK) + transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR is not None: + CHAR.build(train) + TAG.build(train) + PARENT, NEW = PARENT.build(train), NEW.build(train) + PARENT.vocab = NEW.vocab.update(PARENT.vocab) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(NEW.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index, + 'nul_index': NEW.vocab[NUL] + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser + +def flatten(x): + result = [] + for el in x: + if isinstance(el, list): + result.extend(flatten(el)) + else: + result.append(el) + return result + + + + +class AttachJuxtaposeConstituencyParserPos(Parser): + r""" + The implementation of AttachJuxtapose Constituency Parser :cite:`yang-deng-2020-aj`. + """ + + NAME = 'attach-juxtapose-constituency' + MODEL = AttachJuxtaposeConstituencyModelPos + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.NODE = self.transform.NODE + self.PARENT = self.transform.PARENT + self.NEW = self.transform.NEW + self.TAG = self.transform.POS + + #print(self.TAG) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + beam_size: int = 1, + verbose: bool = True, + **kwargs + ): + + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + #print("TRAIN STEP") + words, *feats, postags, nodes, parents, news = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + tags = feats[-1] + #loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + pos_logits + #loss = loss.mean() + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss + + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + + #print("EVAL STEP") + words, *feats, trees, nodes, parents, news = batch + + tags = feats[-1] + #loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask) + qloss + pos_logits + #loss = loss.mean() + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tags[:, 1:-1]) + qloss + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print('CHART PREDS', chart_preds[0]) + + pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)] + trimmed_tags = [tag_seq[1:-1] for tag_seq in tags] + trimmed_tags_list = [x.cpu().tolist() for x in trimmed_tags] + pos_gold = [[self.TAG.vocab[p] for p in sent] for sent in trimmed_tags_list] + + pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds) + pos_acc = pos_correct/pos_total + print(pos_acc) + + """ + tags = [*feats[-1:]] + trimmed_tags = [tag_seq[1:-1] for tag_seq in tags[0]] + gtags = [*feats[-1:]][0].cpu().tolist() + gtags = [x[1:-1] for x in gtags] + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + + tag_tensor = pad_sequence(trimmed_tags, batch_first=True, padding_value=self.TAG.pad_index).to(s_tag.device) + + assert s_tag.shape[:2] == tag_tensor.shape[:2] + print("s_tag shape", s_tag.shape[:2], "tag_tensor shape", tag_tensor.shape[:2]) + + + loss = self.model.loss(x[:, 1:-1], nodes, parents, news, mask, s_tag, tag_tensor) + qloss + + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)] + + for tag_seq in gtags: + for pos in tag_seq: + assert 0 <= pos < len(self.TAG.vocab) + + pos_preds = [[self.TAG.vocab[p.item()] for p in sent] for sent in s_tag.argmax(-1)] + pos_gold = [[self.TAG.vocab[pos] for pos in tag_seq] for tag_seq in gtags] + + #print("357!", len(pos_gold), len(pos_preds)) + pos_correct, pos_total = compute_pos_accuracy(pos_gold, pos_preds) + print("358!", pos_correct, pos_total, pos_correct/pos_total) + + + #print('%%% POS preds', pos_preds) + #print('%%%% POS gold', pos_gold) + """ + + #batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + # for tree, chart in zip(trees, chart_preds)] + #print(batch.trees) + + preds = [AttachJuxtaposeTree.build(tree, [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds[0])] + #print('POS preds', s_tag.argmax(-1)) + #print("PREDS", len(preds)) + + + """ + span_preds = [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0][0]] + #print("new SPAN preds", span_preds) + verbalized_preds = [] + + + for x in chart_preds[1:]: + new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]]] + #print(new_item) + verbalized_preds.append(new_item) + """ + + print("424", loss, pos_acc) + return SpanMetric(loss, + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [AttachJuxtaposeTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees], pos_acc) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask = batch.mask[:, (2+self.args.delay):] + x, s_tag, qloss = self.model(words, feats) + chart_preds = self.model.decode(x[:, 1:-1], mask, self.args.beam_size) + #print('VHART PREDS: ', chart_preds) + #print('LEN VHART PREDS: ', len(chart_preds)) + + chart_preds = chart_preds[0] + batch.trees = [AttachJuxtaposeTree.build(tree, [(i, j, self.NEW.vocab[label]) for i, j, label in chart], {UNK, NUL}) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + raise NotImplementedError("Returning action probs are currently not supported yet.") + + new_tokenwise_preds = [] + #for pred in chart_preds: + # print("PRED", pred) + # new_tokenwise_preds.append([k, y[0], self.NEW.vocab[y[1]], self.NEW.vocab[y[2]]]) + + + span_preds = [[x[0], x[1], self.NEW.vocab[x[2]]] for x in chart_preds[0]] + new_tokenwise_preds + #print("new chart preds", span_preds) + verbalized_preds = [] + for x in chart_preds[1:]: + new_item = [x[0], [x[1][0], self.NEW.vocab[x[1][1]], self.NEW.vocab[x[1][2]]]] + #print(new_item) + verbalized_preds.append(new_item) + #print('POS preds', s_tag.argmax(-1)) + pos_preds = [self.TAG.vocab[x] for x in s_tag.argmax(-1)] + return span_preds, verbalized_preds, pos_preds #batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR = None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay) + TAG = Field('tags', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + TREE = RawField('trees') + NODE, PARENT, NEW = Field('node', use_vocab=False), Field('parent', unk=UNK), Field('new', unk=UNK) + transform = AttachJuxtaposeTree(WORD=(WORD, CHAR), POS=TAG, TREE=TREE, NODE=NODE, PARENT=PARENT, NEW=NEW) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR is not None: + CHAR.build(train) + TAG.build(train) + #print("203 TAG VOCAB", [x for x in TAG.vocab.items()]) + PARENT, NEW = PARENT.build(train), NEW.build(train) + PARENT.vocab = NEW.vocab.update(PARENT.vocab) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(NEW.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index, + 'nul_index': NEW.vocab[NUL] + }) + logger.info(f"{transform}") + + #print('TAG VOCAB', TAG.vocab.items()) + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser \ No newline at end of file diff --git a/tania_scripts/supar/models/const/aj/transform.py b/tania_scripts/supar/models/const/aj/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6 --- /dev/null +++ b/tania_scripts/supar/models/const/aj/transform.py @@ -0,0 +1,459 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union + +import nltk +import torch + +from supar.models.const.crf.transform import Tree +from supar.utils.common import NUL +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class AttachJuxtaposeTree(Tree): + r""" + :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class, + supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + NODE: + The target node on each rightmost chain. + PARENT: + The label of the parent node of each terminal. + NEW: + The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children. + ``NUL`` represents the `Attach` action. + """ + + fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + NODE: Optional[Union[Field, Iterable[Field]]] = None, + PARENT: Optional[Union[Field, Iterable[Field]]] = None, + NEW: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.NODE = NODE + self.PARENT = PARENT + self.NEW = NEW + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.NODE, self.PARENT, self.NEW + + @classmethod + def tree2action(cls, tree: nltk.Tree): + r""" + Converts a constituency tree into AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + A constituency tree in :class:`nltk.tree.Tree` format. + + Returns: + A sequence of AttachJuxtapose actions. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ Arthur)) + (VP + (_ is) + (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + >>> AttachJuxtaposeTree.tree2action(tree) + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')] + """ + + def isroot(node): + return node == tree[0] + + def isterminal(node): + return len(node) == 1 and not isinstance(node[0], nltk.Tree) + + def last_leaf(node): + pos = () + while True: + pos += (len(node) - 1,) + node = node[-1] + if isterminal(node): + return node, pos + + def parent(position): + return tree[position[:-1]] + + def grand(position): + return tree[position[:-2]] + + def detach(tree): + last, last_pos = last_leaf(tree) + siblings = parent(last_pos)[:-1] + + if len(siblings) > 0: + last_subtree = last + last_subtree_siblings = siblings + parent_label = NUL + else: + last_subtree, last_pos = parent(last_pos), last_pos[:-1] + last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1] + parent_label = last_subtree.label() + + target_pos, new_label, last_tree = 0, NUL, tree + if isroot(last_subtree): + last_tree = None + + elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]): + new_label = parent(last_pos).label() + new_label = new_label + target = last_subtree_siblings[0] + last_grand = grand(last_pos) + if last_grand is None: + last_tree = targetistermina + else: + last_grand[-1] = target + target_pos = len(last_pos) - 2 + else: + target = parent(last_pos) + target.pop() + target_pos = len(last_pos) - 2 + action = target_pos, parent_label, new_label + return action, last_tree + if tree is None: + return [] + action, last_tree = detach(tree) + return cls.tree2action(last_tree) + [action] + + @classmethod + def action2tree( + cls, + tree: nltk.Tree, + actions: List[Tuple[int, str, str]], + join: str = '::', + ) -> nltk.Tree: + r""" + Recovers a constituency tree from a sequence of AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + actions (List[Tuple[int, str, str]]): + A sequence of AttachJuxtapose actions. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.action2tree(tree, + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')]).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + """ + + def target(node, depth): + node_pos = () + for _ in range(depth): + node_pos += (len(node) - 1,) + node = node[-1] + return node, node_pos + + def parent(tree, position): + return tree[position[:-1]] + + def execute(tree: nltk.Tree, terminal: Tuple(str, str), action: Tuple[int, str, str]) -> nltk.Tree: + target_pos, parent_label, new_label, post = action + #print(target_pos, parent_label, new_label) + new_leaf = nltk.Tree(post, [terminal[0]]) + + # create the subtree to be inserted + new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf]) + # find the target position at which to insert the new subtree + target_node = tree + if target_node is not None: + target_node, target_pos = target(target_node, target_pos) + + # Attach + if new_label == NUL: + # attach the first token + if target_node is None: + return new_subtree + target_node.append(new_subtree) + # Juxtapose + else: + new_subtree = nltk.Tree(new_label, [target_node, new_subtree]) + if len(target_pos) > 0: + parent_node = parent(tree, target_pos) + parent_node[-1] = new_subtree + else: + tree = new_subtree + return tree + + tree, root, terminals = None, tree.label(), tree.pos() + for terminal, action in zip(terminals, actions): + tree = execute(tree, terminal, action) + # recover unary chains + nodes = [tree] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + nodes.extend(node) + if join in node.label(): + labels = node.label().split(join) + node.set_label(labels[0]) + subtree = nltk.Tree(labels[-1], node) + for label in reversed(labels[1:-1]): + subtree = nltk.Tree(label, [subtree]) + node[:] = [subtree] + return nltk.Tree(root, [tree]) + + @classmethod + def action2span( + cls, + action: torch.Tensor, + spans: torch.Tensor = None, + nul_index: int = -1, + mask: torch.BoolTensor = None + ) -> torch.Tensor: + r""" + Converts a batch of the tensorized action at a given step into spans. + + Args: + action (~torch.Tensor): ``[3, batch_size]``. + A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels. + spans (~torch.Tensor): + Spans generated at previous steps, ``None`` at the first step. Default: ``None``. + nul_index (int): + The index for the obj:`NUL` token, representing the Attach action. Default: -1. + mask (~torch.BoolTensor): ``[batch_size]``. + The mask for covering the unpadded tokens. + + Returns: + A tensor representing a batch of spans for the given step. + + Examples: + >>> from collections import Counter + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab + >>> from supar.utils.common import NUL + >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL), + (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL), + (0, NUL, NUL)]) + >>> vocab = Vocab(Counter(sorted(set([*parents, *news])))) + >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1) + >>> spans = None + >>> for action in actions.unbind(-1): + ... spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL]) + ... + >>> spans + tensor([[[-1, 1, -1, -1, -1, -1, -1, 3], + [-1, -1, -1, -1, -1, -1, 4, -1], + [-1, -1, -1, 1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, 2, -1], + [-1, -1, -1, -1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1]]]) + >>> sequence = torch.where(spans.ge(0)) + >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]])) + >>> sequence + [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')] + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + + """ + + # [batch_size] + target, parent, new = action + if spans is None: + spans = action.new_full((action.shape[1], 2, 2), -1) + spans[:, 0, 1] = parent + return spans + if mask is None: + mask = torch.ones_like(target, dtype=bool) + juxtapose_mask = new.ne(nul_index) & mask + # ancestor nodes are those on the rightmost chain and higher than the target node + # [batch_size, seq_len] + rightmost_mask = spans[..., -1].ge(0) + ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1 + # should not include the target node for the Juxtapose action + ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1)) + target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1] + # the right boundaries of ancestor nodes should be aligned with the new generated terminals + spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1) + spans[..., -2].masked_fill_(ancestor_mask, -1) + spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask] + spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask] + # [batch_size, seq_len+1, seq_len+1] + spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1) + return spans + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[AttachJuxtaposeTreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`AttachJuxtaposeTreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + except IndexError: + tree = nltk.Tree.fromstring('(S ' + s + ')') + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + else: + yield sentence + index += 1 + self.root = tree.label() + + +class AttachJuxtaposeTreeSentence(Sentence): + r""" + Args: + transform (AttachJuxtaposeTree): + A :class:`AttachJuxtaposeTree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: AttachJuxtaposeTree, + tree: nltk.Tree, + index: Optional[int] = None + ) -> AttachJuxtaposeTreeSentence: + super().__init__(transform, index) + + words, tags = zip(*tree.pos()) + nodes, parents, news = None, None, None + if transform.training: + oracle_tree = tree.copy(True) + # the root node must have a unary chain + if len(oracle_tree) > 1: + oracle_tree[:] = [nltk.Tree('*', oracle_tree)] + oracle_tree.collapse_unary(joinChar='::') + if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree): + oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]]) + nodes, parents, news = zip(*transform.tree2action(oracle_tree)) + tags = [x.split("##")[0] for x in tags] + self.values = [words, tags, tree, nodes, parents, news] + + def __repr__(self): + return self.values[-4].pformat(1000000) + + def pretty_print(self): + self.values[-4].pretty_print() diff --git a/tania_scripts/supar/models/const/crf/__init__.py b/tania_scripts/supar/models/const/crf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a1e583e5a598110b37f745bb68b7f473f20149 --- /dev/null +++ b/tania_scripts/supar/models/const/crf/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import CRFConstituencyModel +from .parser import CRFConstituencyParser + +__all__ = ['CRFConstituencyModel', 'CRFConstituencyParser'] diff --git a/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b23ae21ab8db14bfbcb97e155b4203d8c61d82a Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebb3eb81e33d0ac6cbfe3c6b1f467617348d959 Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d01854d3d9e7b162c4bbd2903109ba1d1b0384a1 Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba33210adee6c2b20457972f93ac9ce0b2e68e0 Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e69785c24d5aba0b6e9363ddd28806b9f04cfa Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7282c0aa3c3e6e0aa427723393109f9b057bba50 Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fec7bc1c3fc13a4209bb5c3f1be993b016f7cd7a Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..284c610e9e875040233cc1a2d0138345a2c0e4ae Binary files /dev/null and b/tania_scripts/supar/models/const/crf/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/crf/model.py b/tania_scripts/supar/models/const/crf/model.py new file mode 100644 index 0000000000000000000000000000000000000000..03199e80df175292e83628962f576c11f9bf83be --- /dev/null +++ b/tania_scripts/supar/models/const/crf/model.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, Biaffine +from supar.structs import ConstituencyCRF +from supar.utils import Config + + +class CRFConstituencyModel(Model): + r""" + The implementation of CRF Constituency Parser :cite:`zhang-etal-2020-fast`, + also called FANCY (abbr. of Fast and Accurate Neural Crf constituencY) Parser. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_span_mlp (int): + Span MLP size. Default: 500. + n_label_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_span_mlp=500, + n_label_mlp=100, + mlp_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.span_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout) + self.span_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout) + self.label_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout) + self.label_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout) + + self.span_attn = Biaffine(n_in=n_span_mlp, bias_x=True, bias_y=False) + self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible constituents. + The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds + scores of all possible labels on each constituent. + """ + + x = self.encode(words, feats) + + x_f, x_b = x.chunk(2, -1) + x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1) + + span_l = self.span_mlp_l(x) + span_r = self.span_mlp_r(x) + label_l = self.label_mlp_l(x) + label_r = self.label_mlp_r(x) + + # [batch_size, seq_len, seq_len] + s_span = self.span_attn(span_l, span_r) + # [batch_size, seq_len, seq_len, n_labels] + s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1) + + return s_span, s_label + + def loss(self, s_span, s_label, charts, mask, mbr=True): + r""" + Args: + s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all constituents. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all constituent labels. + charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard labels. Positions without labels are filled with -1. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask for covering the unpadded tokens in each chart. + mbr (bool): + If ``True``, returns marginals for MBR decoding. Default: ``True``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The training loss and original constituent scores + of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise. + """ + + span_mask = charts.ge(0) & mask + span_dist = ConstituencyCRF(s_span, mask[:, 0].sum(-1)) + span_loss = -span_dist.log_prob(charts).sum() / mask[:, 0].sum() + span_probs = span_dist.marginals if mbr else s_span + label_loss = self.criterion(s_label[span_mask], charts[span_mask]) + loss = span_loss + label_loss + + return loss, span_probs + + def decode(self, s_span, s_label, mask): + r""" + Args: + s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all constituents. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all constituent labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + + span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax + label_preds = s_label.argmax(-1).tolist() + return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)] diff --git a/tania_scripts/supar/models/const/crf/parser.py b/tania_scripts/supar/models/const/crf/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5bd16b2da4d46f156b76fa4d64d97c5e88bcd9 --- /dev/null +++ b/tania_scripts/supar/models/const/crf/parser.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Dict, Iterable, Set, Union + +import torch + +from supar.models.const.crf.model import CRFConstituencyModel +from supar.models.const.crf.transform import Tree +from supar.parser import Parser +from supar.structs import ConstituencyCRF +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, EOS, PAD, UNK +from supar.utils.field import ChartField, Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class CRFConstituencyParser(Parser): + r""" + The implementation of CRF Constituency Parser :cite:`zhang-etal-2020-fast`. + """ + + NAME = 'crf-constituency' + MODEL = CRFConstituencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.CHART = self.transform.CHART + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + mbr: bool = True, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + mbr: bool = True, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + mbr: bool = True, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, *feats, _, charts = batch + mask = batch.mask[:, 1:] + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_label = self.model(words, feats) + loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + words, *feats, trees, charts = batch + mask = batch.mask[:, 1:] + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_label = self.model(words, feats) + loss, s_span = self.model.loss(s_span, s_label, charts, mask, self.args.mbr) + chart_preds = self.model.decode(s_span, s_label, mask) + preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart]) + for tree, chart in zip(trees, chart_preds)] + return SpanMetric(loss, + [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask, lens = batch.mask[:, 1:], batch.lens - 2 + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_label = self.model(words, feats) + s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span + chart_preds = self.model.decode(s_span, s_label, mask) + batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart]) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)] + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR, ELMO, BERT = None, None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True) + if 'tag' in args.feat: + TAG = Field('tags', bos=BOS, eos=EOS) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len) + if 'elmo' in args.feat: + from allennlp.modules.elmo import batch_to_ids + ELMO = RawField('elmo') + ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device) + if 'bert' in args.feat: + t = TransformerTokenizer(args.bert) + BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t) + BERT.vocab = t.vocab + TREE = RawField('trees') + CHART = ChartField('charts') + transform = Tree(WORD=(WORD, CHAR, ELMO, BERT), POS=TAG, TREE=TREE, CHART=CHART) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if TAG is not None: + TAG.build(train) + if CHAR is not None: + CHAR.build(train) + CHART.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(CHART.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'bert_pad_index': BERT.pad_index if BERT is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/const/crf/transform.py b/tania_scripts/supar/models/const/crf/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ba03c85e3e46439d62cd2fa13299ac847437b5 --- /dev/null +++ b/tania_scripts/supar/models/const/crf/transform.py @@ -0,0 +1,494 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, + Union) + +import nltk + +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence, Transform + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class Tree(Transform): + r""" + A :class:`Tree` object factorize a constituency tree into four fields, + each associated with one or more :class:`~supar.utils.field.Field` objects. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + CHART: + The factorized sequence of binarized tree traversed in post-order. + """ + + root = '' + fields = ['WORD', 'POS', 'TREE', 'CHART'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + CHART: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.CHART = CHART + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.CHART, + + @classmethod + def totree( + cls, + tokens: List[Union[str, Tuple]], + root: str = '', + normalize: Dict[str, str] = {'(': '-LRB-', ')': '-RRB-'} + ) -> nltk.Tree: + r""" + Converts a list of tokens to a :class:`nltk.tree.Tree`, with missing fields filled in with underscores. + + Args: + tokens (List[Union[str, Tuple]]): + This can be either a list of words or word/pos pairs. + root (str): + The root label of the tree. Default: ''. + normalize (Dict): + Keys within the dict in each token will be replaced by the values. Default: ``{'(': '-LRB-', ')': '-RRB-'}``. + + Returns: + A :class:`nltk.tree.Tree` object. + + Examples: + >>> from supar.models.const.crf.transform import Tree + >>> Tree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP').pprint() + (TOP ( (_ She)) ( (_ enjoys)) ( (_ playing)) ( (_ tennis)) ( (_ .))) + >>> Tree.totree(['(', 'If', 'You', 'Let', 'It', ')'], 'TOP').pprint() + (TOP + ( (_ -LRB-)) + ( (_ If)) + ( (_ You)) + ( (_ Let)) + ( (_ It)) + ( (_ -RRB-))) + """ + + normalize = str.maketrans(normalize) + if isinstance(tokens[0], str): + tokens = [(token, '_') for token in tokens] + return nltk.Tree(root, [nltk.Tree('', [nltk.Tree(pos, [word.translate(normalize)])]) for word, pos in tokens]) + + @classmethod + def binarize( + cls, + tree: nltk.Tree, + left: bool = True, + mark: str = '*', + join: str = '::', + implicit: bool = False + ) -> nltk.Tree: + r""" + Conducts binarization over the tree. + + First, the tree is transformed to satisfy `Chomsky Normal Form (CNF)`_. + Here we call :meth:`~nltk.tree.Tree.chomsky_normal_form` to conduct left-binarization. + Second, all unary productions in the tree are collapsed. + + Args: + tree (nltk.tree.Tree): + The tree to be binarized. + left (bool): + If ``True``, left-binarization is conducted. Default: ``True``. + mark (str): + A string used to mark newly inserted nodes, working if performing explicit binarization. Default: ``'*'``. + join (str): + A string used to connect collapsed node labels. Default: ``'::'``. + implicit (bool): + If ``True``, performs implicit binarization. Default: ``False``. + + Returns: + The binarized tree. + + Examples: + >>> from supar.models.const.crf.transform import Tree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ____________|________________ + | VP | + | _______|_____ | + | | S | + | | | | + | | VP | + | | _____|____ | + NP | | NP | + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> Tree.binarize(tree).pretty_print() + TOP + | + S + _____|__________________ + S* | + __________|_____ | + | VP | + | ___________|______ | + | | S::VP | + | | ______|_____ | + NP VP* VP* NP S* + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> Tree.binarize(tree, implicit=True).pretty_print() + TOP + | + S + _____|__________________ + | + __________|_____ | + | VP | + | ___________|______ | + | | S::VP | + | | ______|_____ | + NP NP + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> Tree.binarize(tree, left=False).pretty_print() + TOP + | + S + ____________|______ + | S* + | ______|___________ + | VP | + | _______|______ | + | | S::VP | + | | ______|_____ | + NP VP* VP* NP S* + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + .. _Chomsky Normal Form (CNF): + https://en.wikipedia.org/wiki/Chomsky_normal_form + """ + + tree = tree.copy(True) + nodes = [tree] + if len(tree) == 1: + if not isinstance(tree[0][0], nltk.Tree): + tree[0] = nltk.Tree(f'{tree.label()}{mark}', [tree[0]]) + nodes = [tree[0]] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + if implicit: + label = '' + else: + label = node.label() + if mark not in label: + label = f'{label}{mark}' + # ensure that only non-terminals can be attached to a n-ary subtree + if len(node) > 1: + for child in node: + if not isinstance(child[0], nltk.Tree): + child[:] = [nltk.Tree(child.label(), child[:])] + child.set_label(label) + # chomsky normal form factorization + if len(node) > 2: + if left: + node[:-1] = [nltk.Tree(label, node[:-1])] + else: + node[1:] = [nltk.Tree(label, node[1:])] + nodes.extend(node) + # collapse unary productions, shoule be conducted after binarization + tree.collapse_unary(joinChar=join) + return tree + + @classmethod + def factorize( + cls, + tree: nltk.Tree, + delete_labels: Optional[Set[str]] = None, + equal_labels: Optional[Dict[str, str]] = None + ) -> Iterable[Tuple]: + r""" + Factorizes the tree into a sequence traversed in post-order. + + Args: + tree (nltk.tree.Tree): + The tree to be factorized. + delete_labels (Optional[Set[str]]): + A set of labels to be ignored. This is used for evaluation. + If it is a pre-terminal label, delete the word along with the brackets. + If it is a non-terminal label, just delete the brackets (don't delete children). + In `EVALB`_, the default set is: + {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''} + Default: ``None``. + equal_labels (Optional[Dict[str, str]]): + The key-val pairs in the dict are considered equivalent (non-directional). This is used for evaluation. + The default dict defined in `EVALB`_ is: {'ADVP': 'PRT'} + Default: ``None``. + + Returns: + The sequence of the factorized tree. + + Examples: + >>> from supar.models.const.crf.transform import Tree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + ''') + >>> Tree.factorize(tree) + [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S'), (0, 5, 'TOP')] + >>> Tree.factorize(tree, delete_labels={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}) + [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S')] + + .. _EVALB: + https://nlp.cs.nyu.edu/evalb/ + """ + + def track(tree, i): + label = tree if isinstance(tree, str) else tree.label() + if delete_labels is not None and label in delete_labels: + label = None + if equal_labels is not None: + label = equal_labels.get(label, label) + if len(tree) == 1 and not isinstance(tree[0], nltk.Tree): + return (i + 1 if label is not None else i), [] + j, spans = i, [] + for child in tree: + j, s = track(child, j) + spans += s + if label is not None and j > i: + spans = spans + [(i, j, label)] + return j, spans + return track(tree, 0)[1] + + @classmethod + def build( + cls, + sentence: Union[nltk.Tree, Iterable], + spans: Iterable[Tuple], + delete_labels: Optional[Set[str]] = None, + mark: Union[str, Tuple[str]] = ('*', '|<>'), + root: str = '', + join: str = '::', + postorder: bool = True + ) -> nltk.Tree: + r""" + Builds a constituency tree from a span sequence. + During building, the sequence is recovered, i.e., de-binarized to the original format. + + Args: + sentence (Union[nltk.tree.Tree, Iterable]): + Sentence to provide a base for building a result tree, both `nltk.tree.Tree` and tokens are allowed. + spans (Iterable[Tuple]): + A list of spans, each consisting of the indices of left/right boundaries and label of the constituent. + delete_labels (Optional[Set[str]]): + A set of labels to be ignored. Default: ``None``. + mark (Union[str, List[str]]): + A string used to mark newly inserted nodes. Non-terminals containing this will be removed. + Default: ``('*', '|<>')``. + root (str): + The root label of the tree, needed if input a list of tokens. Default: ''. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + postorder (bool): + If ``True``, enforces the sequence is sorted in post-order. Default: ``True``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.crf.transform import Tree + >>> Tree.build(['She', 'enjoys', 'playing', 'tennis', '.'], + [(0, 5, 'S'), (0, 4, 'S*'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP*'), + (2, 4, 'S::VP'), (2, 3, 'VP*'), (3, 4, 'NP'), (4, 5, 'S*')], + root='TOP').pretty_print() + TOP + | + S + ____________|________________ + | VP | + | _______|_____ | + | | S | + | | | | + | | VP | + | | _____|____ | + NP | | NP | + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> Tree.build(['She', 'enjoys', 'playing', 'tennis', '.'], + [(0, 1, 'NP'), (3, 4, 'NP'), (2, 4, 'VP'), (2, 4, 'S'), (1, 4, 'VP'), (0, 5, 'S')], + root='TOP').pretty_print() + TOP + | + S + ____________|________________ + | VP | + | _______|_____ | + | | S | + | | | | + | | VP | + | | _____|____ | + NP | | NP | + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + """ + + tree = sentence if isinstance(sentence, nltk.Tree) else Tree.totree(sentence, root) + leaves = [subtree for subtree in tree.subtrees() if not isinstance(subtree[0], nltk.Tree)] + if postorder: + spans = sorted(spans, key=lambda x: (x[1], x[1] - x[0])) + + root = tree.label() + start, stack = 0, [] + for span in spans: + i, j, label = span + if delete_labels is not None and label in delete_labels: + continue + stack.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:i], start)]) + children = [] + while len(stack) > 0 and i <= stack[-1][0]: + children = [stack.pop()] + children + start = children[-1][1] if len(children) > 0 else i + children.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:j], start)]) + start = j + if not label or label.endswith(mark): + stack.extend(children) + continue + labels = label.split(join) + tree = nltk.Tree(labels[-1], [child[-1] for child in children]) + for label in reversed(labels[:-1]): + tree = nltk.Tree(label, [tree]) + stack.append((i, j, tree)) + stack.extend([(n, n + 1, leaf) for n, leaf in enumerate(leaves[start:], start)]) + return nltk.Tree(root, [i[-1] for i in stack]) + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[TreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`TreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = TreeSentence(self, tree, index, **kwargs) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + else: + yield sentence + index += 1 + self.root = tree.label() + + +class TreeSentence(Sentence): + r""" + Args: + transform (Tree): + A :class:`Tree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: Tree, + tree: nltk.Tree, + index: Optional[int] = None, + **kwargs + ) -> TreeSentence: + super().__init__(transform, index) + + words, tags, chart = *zip(*tree.pos()), None + if transform.training: + chart = [[None] * (len(words) + 1) for _ in range(len(words) + 1)] + for i, j, label in Tree.factorize(Tree.binarize(tree, implicit=kwargs.get('implicit', False))[0]): + chart[i][j] = label + self.values = [words, tags, tree, chart] + + def __repr__(self): + return self.values[-2].pformat(1000000) + + def pretty_print(self): + self.values[-2].pretty_print() diff --git a/tania_scripts/supar/models/const/sl/__init__.py b/tania_scripts/supar/models/const/sl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6b7bd1056edd874f9589057530c6fcae8ed25d --- /dev/null +++ b/tania_scripts/supar/models/const/sl/__init__.py @@ -0,0 +1,4 @@ +from .model import SLConstituentModel +from .parser import SLConstituentParser + +__all__ = ['SLConstituentModel', 'SLConstituentParser'] \ No newline at end of file diff --git a/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25ff091642789417721ecf105676e6aa0cec7ae7 Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d46638efe8006ba326a4c8302ee488c3cbf8e14 Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f671041ff9274200b3afc20f09983fcd26145c5b Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dc15de997aec065667819a9092434d52331791f Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d61fb3dc0394c8bb1146467f5c1b762d37cca0b Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdce8977c12fc63d881480ee1b0a5cde8a562fd9 Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1874e521ca0c7d7e3004a1b69a4ae48e1f6bf140 Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbffb6675228c80bd72bea723d3dbbfbcbec55b4 Binary files /dev/null and b/tania_scripts/supar/models/const/sl/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/sl/model.py b/tania_scripts/supar/models/const/sl/model.py new file mode 100644 index 0000000000000000000000000000000000000000..640c1f99b42927dc1c84a4fb23eb3314001eaf23 --- /dev/null +++ b/tania_scripts/supar/models/const/sl/model.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, DecoderLSTM +from supar.utils import Config +from typing import List + + +class SLConstituentModel(Model): + + def __init__(self, + n_words, + n_commons, + n_ancestors, + n_tags=None, + n_chars=None, + encoder='lstm', + feat: List[str] =['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # create decoder + self.common_decoder, self.ancestor_decoder = None, None + if self.args.decoder == 'lstm': + decoder = lambda out_dim: DecoderLSTM( + self.args.n_encoder_hidden, self.args.n_encoder_hidden, out_dim, + self.args.n_decoder_layers, dropout=mlp_dropout, device=self.device + ) + else: + decoder = lambda out_dim: MLP( + n_in=self.args.n_encoder_hidden, n_out=out_dim, + dropout=mlp_dropout, activation=True + ) + + self.common_decoder = decoder(self.args.n_commons) + self.ancestor_decoder = decoder(self.args.n_ancestors) + + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1), + n_out=self.args.n_encoder_hidden, dropout=mlp_dropout) + + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words: torch.Tensor, feats: List[torch.Tensor]=None): + # words, *feats ~ [batch_size, bos + pad(seq_len) + delay, n_encoder_hidden] + x = self.encode(words, feats) + x = x[:, 1:, :] + + # x ~ [batch_size, pad(seq_len), n_encoder_hidden] + batch_size, pad_seq_len, embed_size = x.shape + if self.args.delay != 0: + x = torch.cat([x[:, i:(pad_seq_len - self.args.delay + i), :] for i in range(self.args.delay+1)], dim=2) + x = self.delay_proj(x) + + x, qloss = self.vq_forward(x) # vector quantization + + # s_common ~ [batch_size, pad(seq_len), n_commons] + # s_ancestor ~ [batch_size, pad(seq_len), n_ancestors] + s_common, s_ancestor = self.common_decoder(x), self.ancestor_decoder(x) + return s_common, s_ancestor, qloss + + def loss(self, + s_common: torch.Tensor, s_ancestor: torch.Tensor, + commons: torch.Tensor, ancestors: torch.Tensor, mask: torch.Tensor): + s_common, commons = s_common[mask], commons[mask] + s_ancestor, ancestors = s_ancestor[mask], ancestors[mask] + common_loss = self.criterion(s_common, commons) + ancestor_loss = self.criterion(s_ancestor, ancestors) + return common_loss + ancestor_loss + + def decode(self, s_common, s_ancestor): + common_pred = s_common.argmax(-1) + ancestor_pred = s_ancestor.argmax(-1) + return common_pred, ancestor_pred \ No newline at end of file diff --git a/tania_scripts/supar/models/const/sl/parser.py b/tania_scripts/supar/models/const/sl/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..17e8a3a91e4aa862223a846e045b3d442e1e43dd --- /dev/null +++ b/tania_scripts/supar/models/const/sl/parser.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Dict, Iterable, Set, Union + +import torch, nltk +from supar.models.const.sl.model import SLConstituentModel +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, EOS, PAD, UNK +from supar.utils.field import ChartField, Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.models.const.sl.transform import SLConstituent +from supar.utils.transform import Batch +from supar.models.const.crf.transform import Tree +from supar.codelin import get_con_encoder, LinearizedTree, C_Label +from supar.codelin.utils.constants import BOS as C_BOS, EOS as C_EOS + +logger = get_logger(__name__) + + +class SLConstituentParser(Parser): + + NAME = 'SLConstituentParser' + MODEL = SLConstituentModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.COMMON = self.transform.COMMON + self.ANCESTOR = self.transform.ANCESTOR + self.encoder = self.transform.encoder + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, texts, *feats, commons, ancestors, trees = batch + mask = batch.mask[:, (1+self.args.delay):] + s_common, s_ancestor, qloss = self.model(words, feats) + loss = self.model.loss(s_common, s_ancestor, commons, ancestors, mask) + qloss + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + words, texts, *feats, commons, ancestors, trees = batch + mask = batch.mask[:, (1+self.args.delay):] + + # forward pass + s_common, s_ancestor, qloss = self.model(words, feats) + loss = self.model.loss(s_common, s_ancestor, commons, ancestors, mask) + qloss + + # make predictions of the output of the network + common_preds, ancestor_preds = self.model.decode(s_common, s_ancestor) + + # obtain original tokens to compute decoding + lens = (batch.lens - 1 - self.args.delay).tolist() + common_preds = [self.COMMON.vocab[i.tolist()] for i in common_preds[mask].split(lens)] + ancestor_preds = [self.ANCESTOR.vocab[i.tolist()] for i in ancestor_preds[mask].split(lens)] + # tag_preds = [self.transform.POS.vocab[i.tolist()] for i in tags[mask].split(lens)] + tag_preds = map(lambda tree: tuple(zip(*tree.pos()))[1], trees) + + preds = list() + for i, (forms, upos, common_pred, ancestor_pred) in enumerate(zip(texts, tag_preds, common_preds, ancestor_preds)): + labels = list(map(lambda x: self.encoder.separator.join(x), zip(common_pred, ancestor_pred))) + linearized = list(map(lambda x: '\t'.join(x), zip(forms, upos, labels))) + linearized = [f'{C_BOS}\t{C_BOS}\t{C_BOS}'] + linearized + [f'{C_EOS}\t{C_EOS}\t{C_EOS}'] + linearized = '\n'.join(linearized) + tree = LinearizedTree.from_string(linearized, mode='CONST', separator=self.encoder.separator, + unary_joiner=self.encoder.unary_joiner) + tree = self.encoder.decode(tree) + tree = tree.postprocess_tree('strat_max', clean_nulls=False) + preds.append(nltk.Tree.fromstring(str(tree))) + + if len(preds[-1].leaves()) != len(trees[i].leaves()): + with open('error', 'w') as file: + file.write(linearized) + + + return SpanMetric(loss, + [Tree.factorize(tree, None, None) for tree in preds], + [Tree.factorize(tree, None, None) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, texts, *feats = batch + tags = feats[-1][:, (1+self.args.delay):] + mask = batch.mask[:, (1 + self.args.delay):] + + # forward pass + s_common, s_ancestor, _ = self.model(words, feats) + + # make predictions of the output of the network + common_preds, ancestor_preds = self.model.decode(s_common, s_ancestor) + + # obtain original tokens to compute decoding + lens = (batch.lens - 1 - self.args.delay).tolist() + common_preds = [self.COMMON.vocab[i.tolist()] for i in common_preds[mask].split(lens)] + ancestor_preds = [self.ANCESTOR.vocab[i.tolist()] for i in ancestor_preds[mask].split(lens)] + tag_preds = [self.transform.POS.vocab[i.tolist()] for i in tags[mask].split(lens)] + + preds = list() + for i, (forms, upos, common_pred, ancestor_pred) in enumerate(zip(texts, tag_preds, common_preds, ancestor_preds)): + labels = list(map(lambda x: self.encoder.separator.join(x), zip(common_pred, ancestor_pred))) + linearized = list(map(lambda x: '\t'.join(x), zip(forms, upos, labels))) + linearized = [f'{C_BOS}\t{C_BOS}\t{C_BOS}'] + linearized + [f'{C_EOS}\t{C_EOS}\t{C_EOS}'] + linearized = '\n'.join(linearized) + tree = LinearizedTree.from_string(linearized, mode='CONST', separator=self.encoder.separator, unary_joiner=self.encoder.unary_joiner) + tree = self.encoder.decode(tree) + tree = tree.postprocess_tree('strat_max', clean_nulls=False) + preds.append(nltk.Tree.fromstring(str(tree))) + batch.trees = preds + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR = None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len, delay=args.delay) + TAG = Field('tags', bos=BOS,) + TEXT = RawField('texts') + COMMON = Field('commons') + ANCESTOR = Field('ancestors') + TREE = RawField('trees') + encoder = get_con_encoder(args.codes) + transform = SLConstituent(encoder=encoder, WORD=(WORD, TEXT, CHAR), POS=TAG, + COMMON=COMMON, ANCESTOR=ANCESTOR, TREE=TREE) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR is not None: + CHAR.build(train) + TAG.build(train) + COMMON.build(train) + ANCESTOR.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_commons': len(COMMON.vocab), + 'n_ancestors': len(ANCESTOR.vocab), + 'n_tags': len(TAG.vocab), + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'delay': 0 if 'delay' not in args.keys() else args.delay, + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/const/sl/transform.py b/tania_scripts/supar/models/const/sl/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b94b9660d42e20953d4a031e96aa363f80301357 --- /dev/null +++ b/tania_scripts/supar/models/const/sl/transform.py @@ -0,0 +1,94 @@ +from typing import Union, Iterable, Optional, Set +from supar.utils.field import Field +from supar.utils.transform import Sentence, Transform +from supar.codelin import C_Tree, UNARY_JOINER, LABEL_SEPARATOR +import nltk, re + + +class SLConstituent(Transform): + fields = ['WORD', 'POS', 'COMMON', 'ANCESTOR', 'TREE'] + + def __init__( + self, + encoder, + WORD: Union[Field, Iterable[Field]], + POS: Union[Field, Iterable[Field]], + COMMON: Union[Field, Iterable[Field]], + ANCESTOR: Union[Field, Iterable[Field]], + TREE: Union[Field, Iterable[Field]] + ): + super().__init__() + + self.WORD = WORD + self.POS = POS + self.COMMON = COMMON + self.ANCESTOR = ANCESTOR + self.TREE = TREE + self.encoder = encoder + + + @property + def src(self): + return self.WORD, self.POS + + @property + def tgt(self): + return self.COMMON, self.ANCESTOR, self.TREE + def load( + self, + data: Union[str, Iterable], + **kwargs + ) -> Iterable[Sentence]: + lines = open(data) + index = 0 + for line in lines: + line = line.strip() + if len(line) > 0: + sentence = SLConstituentSentence(self, line, self.encoder, index) + yield sentence + index += 1 + + +def get_nodes(tree: nltk.Tree): + nodes = {tree.label()} + for subtree in tree: + if isinstance(subtree[0], nltk.Tree): + nodes = {*nodes, *get_nodes(subtree)} + return nodes + +def remove_doubles(tree: nltk.Tree): + for i, subtree in enumerate(tree): + if isinstance(subtree, str) and len(tree) > 1: + tree.pop(i) + elif isinstance(subtree, nltk.Tree): + remove_doubles(subtree) + +class SLConstituentSentence(Sentence): + + def __init__(self, transform: SLConstituent, line: str, encoder, index: Optional[int] = None): + super().__init__(transform, index) + + # get nodes of the tree + gold = nltk.Tree.fromstring(line) + nodes = ''.join(get_nodes(gold)) + assert (encoder.separator not in nodes) and (encoder.unary_joiner not in nodes) + + # create linearized tree + tree = C_Tree.from_string(line) + linearized_tree = encoder.encode(tree) + self.annotations = [] + commons = list(map(lambda x: repr(x).split(encoder.separator)[0], linearized_tree.labels)) + ancestors = list(map(lambda x: encoder.separator.join(repr(x).split(encoder.separator)[1:]), linearized_tree.labels)) + + _, postags = zip(*gold.pos()) + self.values = [ + linearized_tree.words, + postags, + commons, ancestors + ] + self.values.append( + nltk.Tree.fromstring(line) + ) + def __repr__(self): + remove_doubles(self.values[-1]) + return re.sub(' +', ' ', str(self.values[-1]).replace('\n', '').replace('\t', '')) diff --git a/tania_scripts/supar/models/const/tt/__init__.py b/tania_scripts/supar/models/const/tt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43892195ca6b97953149e492744adacaf5f01015 --- /dev/null +++ b/tania_scripts/supar/models/const/tt/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import TetraTaggingConstituencyModel +from .parser import TetraTaggingConstituencyParser + +__all__ = ['TetraTaggingConstituencyModel', 'TetraTaggingConstituencyParser'] diff --git a/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a577e63e231c15d7c6ffbde1d8660f3f458b985 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f242de08b1629939a1184b1ebf706bb4b2b2cbda Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4deac65c29f9618638ef3d168e5f3799e148ae96 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08a12831d12bab671d404d55859b4bb81ba4c6e0 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..740bf37c4d0abed35cbd2fe734b4899404da1c63 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d0f9c3cecab729455b1cf901cfed200de9a4883 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2226fc4c38750791963263a693f1896078972e05 Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee00f85bb15eed4b144ed777f9dbfeef8c753e8c Binary files /dev/null and b/tania_scripts/supar/models/const/tt/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/tt/model.py b/tania_scripts/supar/models/const/tt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9bce1b6489bef08ba66494937a7a9c64873582ec --- /dev/null +++ b/tania_scripts/supar/models/const/tt/model.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple + +import torch +import torch.nn as nn + +from supar.model import Model +from supar.utils import Config +from supar.utils.common import INF + + +class TetraTaggingConstituencyModel(Model): + r""" + The implementation of TetraTagging Constituency Parser :cite:`kitaev-klein-2020-tetra`. + + Args: + n_words (int): + The size of the word vocabulary. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layers. Default: .33. + n_gnn_layers (int): + The number of GNN layers. Default: 3. + gnn_dropout (float): + The dropout ratio of GNN layers. Default: .33. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_gnn_layers=3, + gnn_dropout=.33, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.proj = nn.Linear(self.args.n_encoder_hidden, self.args.n_leaves + self.args.n_nodes) + self.criterion = nn.CrossEntropyLoss() + + def forward( + self, + words: torch.LongTensor, + feats: List[torch.LongTensor] = None + ) -> torch.Tensor: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + Scores for all leaves (``[batch_size, seq_len, n_leaves]``) and nodes (``[batch_size, seq_len, n_nodes]``). + """ + + s = self.proj(self.encode(words, feats)[:, 1:-1]) + s_leaf, s_node = s[..., :self.args.n_leaves], s[..., self.args.n_leaves:] + return s_leaf, s_node + + def loss( + self, + s_leaf: torch.Tensor, + s_node: torch.Tensor, + leaves: torch.LongTensor, + nodes: torch.LongTensor, + mask: torch.BoolTensor + ) -> torch.Tensor: + r""" + Args: + s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``. + Leaf scores. + s_node (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``. + Non-terminal scores. + leaves (~torch.LongTensor): ``[batch_size, seq_len]``. + Actions for leaves. + nodes (~torch.LongTensor): ``[batch_size, seq_len]``. + Actions for non-terminals. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + The training loss. + """ + + leaf_mask, node_mask = mask, mask[:, 1:] + leaf_loss = self.criterion(s_leaf[leaf_mask], leaves[leaf_mask]) + node_loss = self.criterion(s_node[:, :-1][node_mask], nodes[node_mask]) if nodes.shape[1] > 0 else 0 + return leaf_loss + node_loss + + def decode( + self, + s_leaf: torch.Tensor, + s_node: torch.Tensor, + mask: torch.BoolTensor, + left_mask: torch.BoolTensor, + depth: int = 8 + ) -> List[List[Tuple]]: + r""" + Args: + s_leaf (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``. + Leaf scores. + s_node (~torch.Tensor): ``[batch_size, seq_len, n_leaves]``. + Non-terminal scores. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + left_mask (~torch.BoolTensor): ``[n_leaves + n_nodes]``. + The mask for distingushing left/rightward actions. + depth (int): + Stack depth. Default: 8. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + from torch_scatter import scatter_max + + lens = mask.sum(-1) + batch_size, seq_len, n_leaves = s_leaf.shape + end_mask = (lens - 1).unsqueeze(-1).eq(lens.new_tensor(range(seq_len))) + leaf_left_mask, node_left_mask = left_mask[:n_leaves], left_mask[n_leaves:] + s_leaf = s_leaf.masked_fill_(end_mask.unsqueeze(-1) & leaf_left_mask, -INF) + # [n_leaves], [n_nodes] + changes = (torch.where(leaf_left_mask, 1, 0), torch.where(node_left_mask, 0, -1)) + # [batch_size, depth] + depths = lens.new_full((depth,), -2).index_fill_(-1, lens.new_tensor(0), -1).repeat(batch_size, 1) + # [2, batch_size, depth, seq_len] + labels, paths = lens.new_zeros(2, batch_size, depth, seq_len), lens.new_zeros(2, batch_size, depth, seq_len) + # [batch_size, depth] + s = s_leaf.new_zeros(batch_size, depth) + + def advance(s, s_t, depths, changes): + batch_size, n_labels = s_t.shape + # [batch_size, depth * n_labels] + depths = (depths.unsqueeze(-1) + changes).view(batch_size, -1) + # [batch_size, depth, n_labels] + s_t = s.unsqueeze(-1) + s_t.unsqueeze(1) + # [batch_size, depth * n_labels] + # fill scores of invalid depths with -INF + s_t = s_t.view(batch_size, -1).masked_fill_((depths < 0).logical_or_(depths >= depth), -INF) + # [batch_size, depth] + # for each depth, we use the `scatter_max` trick to obtain the 1-best label + s, ls = scatter_max(s_t, depths.clamp(0, depth - 1), -1, s_t.new_full((batch_size, depth), -INF)) + # [batch_size, depth] + depths = depths.gather(-1, ls.clamp(0, depths.shape[1] - 1)).masked_fill_(s.eq(-INF), -1) + ll = ls % n_labels + lp = depths - changes[ll] + return s, ll, lp, depths + + for t in range(seq_len): + m = lens.gt(t) + s[m], labels[0, m, :, t], paths[0, m, :, t], depths[m] = advance(s[m], s_leaf[m, t], depths[m], changes[0]) + if t == seq_len - 1: + break + m = lens.gt(t + 1) + s[m], labels[1, m, :, t], paths[1, m, :, t], depths[m] = advance(s[m], s_node[m, t], depths[m], changes[1]) + + lens = lens.tolist() + labels, paths = labels.movedim((0, 2), (2, 3))[mask].split(lens), paths.movedim((0, 2), (2, 3))[mask].split(lens) + leaves, nodes = [], [] + for i, length in enumerate(lens): + leaf_labels, node_labels = labels[i].transpose(0, 1).tolist() + leaf_paths, node_paths = paths[i].transpose(0, 1).tolist() + leaf_pred, node_pred, prev = [leaf_labels[-1][0]], [], leaf_paths[-1][0] + for j in reversed(range(length - 1)): + node_pred.append(node_labels[j][prev]) + prev = node_paths[j][prev] + leaf_pred.append(leaf_labels[j][prev]) + prev = leaf_paths[j][prev] + leaves.append(list(reversed(leaf_pred))) + nodes.append(list(reversed(node_pred))) + return leaves, nodes diff --git a/tania_scripts/supar/models/const/tt/parser.py b/tania_scripts/supar/models/const/tt/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..7289c222a454629876ecc155721fef36ef79a0fb --- /dev/null +++ b/tania_scripts/supar/models/const/tt/parser.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Dict, Iterable, Set, Union + +import torch + +from supar.models.const.tt.model import TetraTaggingConstituencyModel +from supar.models.const.tt.transform import TetraTaggingTree +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, EOS, PAD, UNK +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class TetraTaggingConstituencyParser(Parser): + r""" + The implementation of TetraTagging Constituency Parser :cite:`kitaev-klein-2020-tetra`. + """ + + NAME = 'tetra-tagging-constituency' + MODEL = TetraTaggingConstituencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TREE = self.transform.TREE + self.LEAF = self.transform.LEAF + self.NODE = self.transform.NODE + + self.left_mask = torch.tensor([*(i.startswith('l') for i in self.LEAF.vocab.itos), + *(i.startswith('L') for i in self.NODE.vocab.itos)]).to(self.device) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + depth: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + depth: int = 1, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + depth: int = 1, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, *feats, _, leaves, nodes = batch + mask = batch.mask[:, 2:] + s_leaf, s_node = self.model(words, feats) + loss = self.model.loss(s_leaf, s_node, leaves, nodes, mask) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + words, *feats, trees, leaves, nodes = batch + mask = batch.mask[:, 2:] + s_leaf, s_node = self.model(words, feats) + loss = self.model.loss(s_leaf, s_node, leaves, nodes, mask) + preds = self.model.decode(s_leaf, s_node, mask, self.left_mask, self.args.depth) + preds = [TetraTaggingTree.action2tree(tree, (self.LEAF.vocab[i], self.NODE.vocab[j] if len(j) > 0 else [])) + for tree, i, j in zip(trees, *preds)] + return SpanMetric(loss, + [TetraTaggingTree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [TetraTaggingTree.factorize(tree, self.args.delete, self.args.equal) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask = batch.mask[:, 2:] + s_leaf, s_node = self.model(words, feats) + preds = self.model.decode(s_leaf, s_node, mask, self.left_mask, self.args.depth) + batch.trees = [TetraTaggingTree.action2tree(tree, (self.LEAF.vocab[i], self.NODE.vocab[j] if len(j) > 0 else [])) + for tree, i, j in zip(trees, *preds)] + if self.args.prob: + raise NotImplementedError("Returning action probs are currently not supported yet.") + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.WORD[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR, ELMO, BERT = None, None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True) + if 'tag' in args.feat: + TAG = Field('tags', bos=BOS, eos=EOS) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len) + if 'elmo' in args.feat: + from allennlp.modules.elmo import batch_to_ids + ELMO = RawField('elmo') + ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device) + if 'bert' in args.feat: + t = TransformerTokenizer(args.bert) + BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t) + BERT.vocab = t.vocab + TREE = RawField('trees') + LEAF, NODE = Field('leaf'), Field('node') + transform = TetraTaggingTree(WORD=(WORD, CHAR, ELMO, BERT), POS=TAG, TREE=TREE, LEAF=LEAF, NODE=NODE) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if TAG is not None: + TAG.build(train) + if CHAR is not None: + CHAR.build(train) + LEAF, NODE = LEAF.build(train), NODE.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_leaves': len(LEAF.vocab), + 'n_nodes': len(NODE.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'bert_pad_index': BERT.pad_index if BERT is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'eos_index': WORD.eos_index + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/const/tt/transform.py b/tania_scripts/supar/models/const/tt/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..a337b94eaaab48f7a61568a4a396604ffe839691 --- /dev/null +++ b/tania_scripts/supar/models/const/tt/transform.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union, Sequence + +import nltk + +from supar.models.const.crf.transform import Tree +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class TetraTaggingTree(Tree): + r""" + :class:`TetraTaggingTree` is derived from the :class:`Tree` class and is defined for supporting the transition system of + tetra tagger :cite:`kitaev-klein-2020-tetra`. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + LEAF: + Action labels in tetra tagger transition system. + NODE: + Non-terminal labels. + """ + + fields = ['WORD', 'POS', 'TREE', 'LEAF', 'NODE'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + LEAF: Optional[Union[Field, Iterable[Field]]] = None, + NODE: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.LEAF = LEAF + self.NODE = NODE + + @property + def tgt(self): + return self.LEAF, self.NODE + + @classmethod + def tree2action(cls, tree: nltk.Tree) -> Tuple[Sequence, Sequence]: + r""" + Converts a (binarized) constituency tree into tetra-tagging actions. + + Args: + tree (nltk.tree.Tree): + A constituency tree in :class:`nltk.tree.Tree` format. + + Returns: + Tetra-tagging actions for leaves and non-terminals. + + Examples: + >>> from supar.models.const.tt.transform import TetraTaggingTree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ She)) + (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ____________|________________ + | VP | + | _______|_____ | + | | S | + | | | | + | | VP | + | | _____|____ | + NP | | NP | + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> tree = TetraTaggingTree.binarize(tree, left=False, implicit=True) + >>> tree.pretty_print() + TOP + | + S + ____________|______ + | + | ______|___________ + | VP | + | _______|______ | + | | S::VP | + | | ______|_____ | + NP NP + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + >>> TetraTaggingTree.tree2action(tree) + (['l/NP', 'l/', 'l/', 'r/NP', 'r/'], ['L/S', 'L/VP', 'R/S::VP', 'R/']) + """ + + def traverse(tree: nltk.Tree, left: bool = True) -> List: + if len(tree) == 1 and not isinstance(tree[0], nltk.Tree): + return ['l' if left else 'r'], [] + if len(tree) == 1 and not isinstance(tree[0][0], nltk.Tree): + return [f"{'l' if left else 'r'}/{tree.label()}"], [] + return tuple(sum(i, []) for i in zip(*[traverse(tree[0]), + ([], [f'{("L" if left else "R")}/{tree.label()}']), + traverse(tree[1], False)])) + return traverse(tree[0]) + + @classmethod + def action2tree( + cls, + tree: nltk.Tree, + actions: Tuple[Sequence, Sequence], + mark: Union[str, Tuple[str]] = ('*', '|<>'), + join: str = '::', + ) -> nltk.Tree: + r""" + Recovers a constituency tree from tetra-tagging actions. + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + actions (Tuple[Sequence, Sequence]): + Tetra-tagging actions. + mark (Union[str, List[str]]): + A string used to mark newly inserted nodes. Non-terminals containing this will be removed. + Default: ``('*', '|<>')``. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.tt.transform import TetraTaggingTree + >>> tree = TetraTaggingTree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP') + >>> actions = (['l/NP', 'l/', 'l/', 'r/NP', 'r/'], ['L/S', 'L/VP', 'R/S::VP', 'R/']) + >>> TetraTaggingTree.action2tree(tree, actions).pretty_print() + TOP + | + S + ____________|________________ + | VP | + | _______|_____ | + | | S | + | | | | + | | VP | + | | _____|____ | + NP | | NP | + | | | | | + _ _ _ _ _ + | | | | | + She enjoys playing tennis . + + """ + + stack = [] + leaves = [nltk.Tree(pos, [token]) for token, pos in tree.pos()] + for i, (al, an) in enumerate(zip(*actions)): + leaf = nltk.Tree(al.split('/', 1)[1], [leaves[i]]) + if al.startswith('l'): + stack.append([leaf, None]) + else: + slot = stack[-1][1] + slot.append(leaf) + if an.startswith('L'): + node = nltk.Tree(an.split('/', 1)[1], [stack[-1][0]]) + stack[-1][0] = node + else: + node = nltk.Tree(an.split('/', 1)[1], [stack.pop()[0]]) + slot = stack[-1][1] + slot.append(node) + stack[-1][1] = node + # the last leaf must be leftward + leaf = nltk.Tree(actions[0][-1].split('/', 1)[1], [leaves[-1]]) + if len(stack) > 0: + stack[-1][1].append(leaf) + else: + stack.append([leaf, None]) + + def debinarize(tree): + if len(tree) == 1 and not isinstance(tree[0], nltk.Tree): + return [tree] + label, children = tree.label(), [] + for child in tree: + children.extend(debinarize(child)) + if not label or label.endswith(mark): + return children + labels = label.split(join) if join in label else [label] + tree = nltk.Tree(labels[-1], children) + for label in reversed(labels[:-1]): + tree = nltk.Tree(label, [tree]) + return [tree] + return debinarize(nltk.Tree(tree.label(), [stack[0][0]]))[0] + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[TetraTaggingTreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`TetraTaggingTreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = TetraTaggingTreeSentence(self, tree, index) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + else: + yield sentence + index += 1 + self.root = tree.label() + + +class TetraTaggingTreeSentence(Sentence): + r""" + Args: + transform (TetraTaggingTree): + A :class:`TetraTaggingTree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: TetraTaggingTree, + tree: nltk.Tree, + index: Optional[int] = None + ) -> TetraTaggingTreeSentence: + super().__init__(transform, index) + + words, tags = zip(*tree.pos()) + leaves, nodes = None, None + if transform.training: + oracle_tree = tree.copy(True) + # the root node must have a unary chain + if len(oracle_tree) > 1: + oracle_tree[:] = [nltk.Tree('*', oracle_tree)] + oracle_tree = TetraTaggingTree.binarize(oracle_tree, left=False, implicit=True) + if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree): + oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]]) + leaves, nodes = transform.tree2action(oracle_tree) + self.values = [words, tags, tree, leaves, nodes] + + def __repr__(self): + return self.values[-3].pformat(1000000) + + def pretty_print(self): + self.values[-3].pretty_print() diff --git a/tania_scripts/supar/models/const/vi/__init__.py b/tania_scripts/supar/models/const/vi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db91608946d018e5adb1dca9e49acb144fc7eb31 --- /dev/null +++ b/tania_scripts/supar/models/const/vi/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import VIConstituencyModel +from .parser import VIConstituencyParser + +__all__ = ['VIConstituencyModel', 'VIConstituencyParser'] diff --git a/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..061ba0053d502449b355ff1e49fecf53cd5b85e9 Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5780eac044e7f23e6effee39efea6913adc7c972 Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74c5b353f48c7a6aaa1d1f8944496ff7070b89ec Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef42262709e0f807a94447a265c803f5688b9c97 Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aec5892d74be8c41b6e936b63abef924b1a5841 Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2279ac265980ece4415c326074e52e8ec8ef509 Binary files /dev/null and b/tania_scripts/supar/models/const/vi/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/const/vi/model.py b/tania_scripts/supar/models/const/vi/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c44daac95be0d324e239a2fd5ad6d791a0be64ea --- /dev/null +++ b/tania_scripts/supar/models/const/vi/model.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.models.const.crf.model import CRFConstituencyModel +from supar.modules import MLP, Biaffine, Triaffine +from supar.structs import ConstituencyCRF, ConstituencyLBP, ConstituencyMFVI +from supar.utils import Config + + +class VIConstituencyModel(CRFConstituencyModel): + r""" + The implementation of Constituency Parser using variational inference. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_span_mlp (int): + Span MLP size. Default: 500. + n_pair_mlp (int): + Binary factor MLP size. Default: 100. + n_label_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + inference (str): + Approximate inference methods. Default: ``mfvi``. + max_iter (int): + Max iteration times for inference. Default: 3. + interpolation (int): + Constant to even out the label/edge loss. Default: .1. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, True), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_span_mlp=500, + n_pair_mlp=100, + n_label_mlp=100, + mlp_dropout=.33, + inference='mfvi', + max_iter=3, + interpolation=0.1, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.span_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout) + self.span_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_span_mlp, dropout=mlp_dropout) + self.pair_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout) + self.pair_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout) + self.pair_mlp_b = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=mlp_dropout) + self.label_mlp_l = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout) + self.label_mlp_r = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=mlp_dropout) + + self.span_attn = Biaffine(n_in=n_span_mlp, bias_x=True, bias_y=False) + self.pair_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=False) + self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True) + self.inference = (ConstituencyMFVI if inference == 'mfvi' else ConstituencyLBP)(max_iter) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + + Returns: + ~torch.Tensor, ~torch.Tensor, ~torch.Tensor: + Scores of all possible constituents (``[batch_size, seq_len, seq_len]``), + second-order triples (``[batch_size, seq_len, seq_len, n_labels]``) and + all possible labels on each constituent (``[batch_size, seq_len, seq_len, n_labels]``). + """ + + x = self.encode(words, feats) + + x_f, x_b = x.chunk(2, -1) + x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1) + + span_l = self.span_mlp_l(x) + span_r = self.span_mlp_r(x) + pair_l = self.pair_mlp_l(x) + pair_r = self.pair_mlp_r(x) + pair_b = self.pair_mlp_b(x) + label_l = self.label_mlp_l(x) + label_r = self.label_mlp_r(x) + + # [batch_size, seq_len, seq_len] + s_span = self.span_attn(span_l, span_r) + s_pair = self.pair_attn(pair_l, pair_r, pair_b).permute(0, 3, 1, 2) + # [batch_size, seq_len, seq_len, n_labels] + s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1) + + return s_span, s_pair, s_label + + def loss(self, s_span, s_pair, s_label, charts, mask): + r""" + Args: + s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all constituents. + s_pair (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of second-order triples. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all constituent labels. + charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard labels. Positions without labels are filled with -1. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The training loss and marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + span_mask = charts.ge(0) & mask + span_loss, span_probs = self.inference((s_span, s_pair), mask, span_mask) + label_loss = self.criterion(s_label[span_mask], charts[span_mask]) + loss = self.args.interpolation * label_loss + (1 - self.args.interpolation) * span_loss + + return loss, span_probs + + def decode(self, s_span, s_label, mask): + r""" + Args: + s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all constituents. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all constituent labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + List[List[Tuple]]: + Sequences of factorized labeled trees. + """ + + span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax + label_preds = s_label.argmax(-1).tolist() + return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)] diff --git a/tania_scripts/supar/models/const/vi/parser.py b/tania_scripts/supar/models/const/vi/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5721a9dcc78f5c7e16484e4c18bc25fed51d81fd --- /dev/null +++ b/tania_scripts/supar/models/const/vi/parser.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Iterable, Set, Union + +import torch + +from supar.models.const.crf.parser import CRFConstituencyParser +from supar.models.const.crf.transform import Tree +from supar.models.const.vi.model import VIConstituencyModel +from supar.utils import Config +from supar.utils.logging import get_logger +from supar.utils.metric import SpanMetric +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class VIConstituencyParser(CRFConstituencyParser): + r""" + The implementation of Constituency Parser using variational inference. + """ + + NAME = 'vi-constituency' + MODEL = VIConstituencyModel + + def train( + self, + train, + dev, + test, + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, workers: int = 0, amp: bool = False, cache: bool = False, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + delete: Set = {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}, + equal: Dict = {'ADVP': 'PRT'}, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, *feats, _, charts = batch + mask = batch.mask[:, 1:] + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_pair, s_label = self.model(words, feats) + loss, _ = self.model.loss(s_span, s_pair, s_label, charts, mask) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> SpanMetric: + words, *feats, trees, charts = batch + mask = batch.mask[:, 1:] + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_pair, s_label = self.model(words, feats) + loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, mask) + chart_preds = self.model.decode(s_span, s_label, mask) + preds = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart]) + for tree, chart in zip(trees, chart_preds)] + return SpanMetric(loss, + [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in preds], + [Tree.factorize(tree, self.args.delete, self.args.equal) for tree in trees]) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats, trees = batch + mask, lens = batch.mask[:, 1:], batch.lens - 2 + mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1) + s_span, s_pair, s_label = self.model(words, feats) + s_span = self.model.inference((s_span, s_pair), mask) + chart_preds = self.model.decode(s_span, s_label, mask) + batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart]) + for tree, chart in zip(trees, chart_preds)] + if self.args.prob: + batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)] + return batch diff --git a/tania_scripts/supar/models/dep/__init__.py b/tania_scripts/supar/models/dep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aad535787edd8e64b6a673c79f969d10f3133a39 --- /dev/null +++ b/tania_scripts/supar/models/dep/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +from .biaffine import BiaffineDependencyModel, BiaffineDependencyParser +from .crf import CRFDependencyModel, CRFDependencyParser +from .crf2o import CRF2oDependencyModel, CRF2oDependencyParser +from .vi import VIDependencyModel, VIDependencyParser +from .sl import SLDependencyModel, SLDependencyParser +from .eager import ArcEagerDependencyModel, ArcEagerDependencyParser + +__all__ = ['BiaffineDependencyModel', 'BiaffineDependencyParser', + 'CRFDependencyModel', 'CRFDependencyParser', + 'CRF2oDependencyModel', 'CRF2oDependencyParser', + 'VIDependencyModel', 'VIDependencyParser', + 'SLDependencyModel', 'SLDependencyParser', + 'ArcEagerDependencyModel', 'ArcEagerDependencyParser'] diff --git a/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7861152ee2c2a8276035f9246170cd9dc9cc615c Binary files /dev/null and b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e046f387fd327180d2c8a0e599fbf49639f9884c Binary files /dev/null and b/tania_scripts/supar/models/dep/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__init__.py b/tania_scripts/supar/models/dep/biaffine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d757c65a2c15d8e495528d9029ad34440eee6ebc --- /dev/null +++ b/tania_scripts/supar/models/dep/biaffine/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import BiaffineDependencyModel +from .parser import BiaffineDependencyParser + +__all__ = ['BiaffineDependencyModel', 'BiaffineDependencyParser'] diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03c15f409c6cb82e0c38e8de1c728163143075d Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cc85e00431ce55159ade848598530026f4c870f Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c463ffc47a61e67f929c1d390efdc163db1ecc9 Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec57b786c69963f7b6418e915051312869bea45f Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a161c0ca80e6f499166fc7e4f195252909a9dc8 Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41daea0ef2c2181fcba2b5fbf9a540eded84891b Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a172f6c96a8bba92e2b94ef264d11b820f66f204 Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b880c597dadc460d71c38295eb68e34d01977a2 Binary files /dev/null and b/tania_scripts/supar/models/dep/biaffine/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/biaffine/model.py b/tania_scripts/supar/models/dep/biaffine/model.py new file mode 100644 index 0000000000000000000000000000000000000000..14420cb8563dbbdd606afad43b6f699c293d2d68 --- /dev/null +++ b/tania_scripts/supar/models/dep/biaffine/model.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.model import Model +from supar.models.dep.biaffine.transform import CoNLL +from supar.modules import MLP, Biaffine +from supar.structs import DependencyCRF, MatrixTree +from supar.utils import Config +from supar.utils.common import MIN + + +class BiaffineDependencyModel(Model): + r""" + The implementation of Biaffine Dependency Parser :cite:`dozat-etal-2017-biaffine`. + + Args: + n_words (int): + The size of the word vocabulary. + n_rels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_arc_mlp (int): + Arc MLP size. Default: 500. + n_rel_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + scale (float): + Scaling factor for affine scores. Default: 0. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_rels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['tag', 'char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + + self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False) + self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible arcs. + The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds + scores of all possible labels on each arc. + """ + + x = self.encode(words, feats) + mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1) + + arc_d = self.arc_mlp_d(x) + arc_h = self.arc_mlp_h(x) + rel_d = self.rel_mlp_d(x) + rel_h = self.rel_mlp_h(x) + + # [batch_size, seq_len, seq_len] + s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN) + # [batch_size, seq_len, seq_len, n_rels] + s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1) + + return s_arc, s_rel + + def loss(self, s_arc, s_rel, arcs, rels, mask, partial=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + arcs (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard arcs. + rels (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + partial (bool): + ``True`` denotes the trees are partially annotated. Default: ``False``. + + Returns: + ~torch.Tensor: + The training loss. + """ + + if partial: + mask = mask & arcs.ge(0) + s_arc, arcs = s_arc[mask], arcs[mask] + s_rel, rels = s_rel[mask], rels[mask] + s_rel = s_rel[torch.arange(len(arcs)), arcs] + arc_loss = self.criterion(s_arc, arcs) + rel_loss = self.criterion(s_rel, rels) + + return arc_loss + rel_loss + + def decode(self, s_arc, s_rel, mask, tree=False, proj=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + + Returns: + ~torch.LongTensor, ~torch.LongTensor: + Predicted arcs and labels of shape ``[batch_size, seq_len]``. + """ + + lens = mask.sum(1) + arc_preds = s_arc.argmax(-1) + bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())] + if tree and any(bad): + arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(s_arc[bad], mask[bad].sum(-1)).argmax + rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1) + + return arc_preds, rel_preds diff --git a/tania_scripts/supar/models/dep/biaffine/parser.py b/tania_scripts/supar/models/dep/biaffine/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b3093fc323131b372ffd744d5de9cf1e117cf9 --- /dev/null +++ b/tania_scripts/supar/models/dep/biaffine/parser.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.model import BiaffineDependencyModel +from supar.models.dep.biaffine.transform import CoNLL +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, PAD, UNK +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.fn import ispunct +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class BiaffineDependencyParser(Parser): + r""" + The implementation of Biaffine Dependency Parser :cite:`dozat-etal-2017-biaffine`. + """ + + NAME = 'biaffine-dependency' + MODEL = BiaffineDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TAG = self.transform.CPOS + self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = False, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = True, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + tree: bool = True, + proj: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + if self.args.partial: + mask &= arcs.ge(0) + # ignore all punctuation if not specified + if not self.args.punct: + mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words])) + return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, _, *feats = batch + mask, lens = batch.mask, (batch.lens - 1).tolist() + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)] + batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + if self.args.prob: + batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())] + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. + Required if taking words as encoder input. + Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR, ELMO, BERT = None, None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True) + if 'tag' in args.feat: + TAG = Field('tags', bos=BOS) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len) + if 'elmo' in args.feat: + from allennlp.modules.elmo import batch_to_ids + ELMO = RawField('elmo') + ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device) + if 'bert' in args.feat: + t = TransformerTokenizer(args.bert) + BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + BERT.vocab = t.vocab + TEXT = RawField('texts') + ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs) + REL = Field('rels', bos=BOS) + transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=ARC, DEPREL=REL) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if TAG is not None: + TAG.build(train) + if CHAR is not None: + CHAR.build(train) + REL.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_rels': len(REL.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'bert_pad_index': BERT.pad_index if BERT is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/dep/biaffine/transform.py b/tania_scripts/supar/models/dep/biaffine/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..073572b57a2610a829253e698d4f4f1bbee05b85 --- /dev/null +++ b/tania_scripts/supar/models/dep/biaffine/transform.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from io import StringIO +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union + +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence, Transform + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class CoNLL(Transform): + r""" + A :class:`CoNLL` object holds ten fields required for CoNLL-X data format :cite:`buchholz-marsi-2006-conll`. + Each field can be bound to one or more :class:`~supar.utils.field.Field` objects. + For example, ``FORM`` can contain both :class:`~supar.utils.field.Field` and :class:`~supar.utils.field.SubwordField` + to produce tensors for words and subwords. + + Attributes: + ID: + Token counter, starting at 1. + FORM: + Words in the sentence. + LEMMA: + Lemmas or stems (depending on the particular treebank) of words, or underscores if not available. + CPOS: + Coarse-grained part-of-speech tags, where the tagset depends on the treebank. + POS: + Fine-grained part-of-speech tags, where the tagset depends on the treebank. + FEATS: + Unordered set of syntactic and/or morphological features (depending on the particular treebank), + or underscores if not available. + HEAD: + Heads of the tokens, which are either values of ID or zeros. + DEPREL: + Dependency relations to the HEAD. + PHEAD: + Projective heads of tokens, which are either values of ID or zeros, or underscores if not available. + PDEPREL: + Dependency relations to the PHEAD, or underscores if not available. + """ + + fields = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL'] + + def __init__( + self, + ID: Optional[Union[Field, Iterable[Field]]] = None, + FORM: Optional[Union[Field, Iterable[Field]]] = None, + LEMMA: Optional[Union[Field, Iterable[Field]]] = None, + CPOS: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + FEATS: Optional[Union[Field, Iterable[Field]]] = None, + HEAD: Optional[Union[Field, Iterable[Field]]] = None, + DEPREL: Optional[Union[Field, Iterable[Field]]] = None, + PHEAD: Optional[Union[Field, Iterable[Field]]] = None, + PDEPREL: Optional[Union[Field, Iterable[Field]]] = None + ) -> CoNLL: + super().__init__() + + self.ID = ID + self.FORM = FORM + self.LEMMA = LEMMA + self.CPOS = CPOS + self.POS = POS + self.FEATS = FEATS + self.HEAD = HEAD + self.DEPREL = DEPREL + self.PHEAD = PHEAD + self.PDEPREL = PDEPREL + + @property + def src(self): + return self.FORM, self.LEMMA, self.CPOS, self.POS, self.FEATS + + @property + def tgt(self): + return self.HEAD, self.DEPREL, self.PHEAD, self.PDEPREL + + @classmethod + def get_arcs(cls, sequence, placeholder='_'): + return [-1 if i == placeholder else int(i) for i in sequence] + + @classmethod + def get_sibs(cls, sequence, placeholder='_'): + sibs = [[0] * (len(sequence) + 1) for _ in range(len(sequence) + 1)] + heads = [0] + [-1 if i == placeholder else int(i) for i in sequence] + + for i, hi in enumerate(heads[1:], 1): + for j, hj in enumerate(heads[i + 1:], i + 1): + di, dj = hi - i, hj - j + if hi >= 0 and hj >= 0 and hi == hj and di * dj > 0: + if abs(di) > abs(dj): + sibs[i][hi] = j + else: + sibs[j][hj] = i + break + return sibs[1:] + + @classmethod + def get_edges(cls, sequence): + edges = [[0] * (len(sequence) + 1) for _ in range(len(sequence) + 1)] + for i, s in enumerate(sequence, 1): + if s != '_': + for pair in s.split('|'): + edges[i][int(pair.split(':')[0])] = 1 + return edges + + @classmethod + def get_labels(cls, sequence): + labels = [[None] * (len(sequence) + 1) for _ in range(len(sequence) + 1)] + for i, s in enumerate(sequence, 1): + if s != '_': + for pair in s.split('|'): + edge, label = pair.split(':', 1) + labels[i][int(edge)] = label + return labels + + @classmethod + def build_relations(cls, chart): + sequence = ['_'] * len(chart) + for i, row in enumerate(chart): + pairs = [(j, label) for j, label in enumerate(row) if label is not None] + if len(pairs) > 0: + sequence[i] = '|'.join(f"{head}:{label}" for head, label in pairs) + return sequence + + @classmethod + def toconll(cls, tokens: List[Union[str, Tuple]]) -> str: + r""" + Converts a list of tokens to a string in CoNLL-X format with missing fields filled with underscores. + + Args: + tokens (List[Union[str, Tuple]]): + This can be either a list of words, word/pos pairs or word/lemma/pos triples. + + Returns: + A string in CoNLL-X format. + + Examples: + >>> print(CoNLL.toconll(['She', 'enjoys', 'playing', 'tennis', '.'])) + 1 She _ _ _ _ _ _ _ _ + 2 enjoys _ _ _ _ _ _ _ _ + 3 playing _ _ _ _ _ _ _ _ + 4 tennis _ _ _ _ _ _ _ _ + 5 . _ _ _ _ _ _ _ _ + + >>> print(CoNLL.toconll([('She', 'she', 'PRP'), + ('enjoys', 'enjoy', 'VBZ'), + ('playing', 'play', 'VBG'), + ('tennis', 'tennis', 'NN'), + ('.', '_', '.')])) + 1 She she PRP _ _ _ _ _ _ + 2 enjoys enjoy VBZ _ _ _ _ _ _ + 3 playing play VBG _ _ _ _ _ _ + 4 tennis tennis NN _ _ _ _ _ _ + 5 . _ . _ _ _ _ _ _ + + """ + + if isinstance(tokens[0], str): + s = '\n'.join([f"{i}\t{word}\t" + '\t'.join(['_'] * 8) + for i, word in enumerate(tokens, 1)]) + elif len(tokens[0]) == 2: + s = '\n'.join([f"{i}\t{word}\t_\t{tag}\t" + '\t'.join(['_'] * 6) + for i, (word, tag) in enumerate(tokens, 1)]) + elif len(tokens[0]) == 3: + s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t" + '\t'.join(['_'] * 6) + for i, (word, lemma, tag) in enumerate(tokens, 1)]) + else: + raise RuntimeError(f"Invalid sequence {tokens}. Only list of str or list of word/pos/lemma tuples are support.") + return s + '\n' + + @classmethod + def isprojective(cls, sequence: List[int]) -> bool: + r""" + Checks if a dependency tree is projective. + This also works for partial annotation. + + Besides the obvious crossing arcs, the examples below illustrate two non-projective cases + which are hard to detect in the scenario of partial annotation. + + Args: + sequence (List[int]): + A list of head indices. + + Returns: + ``True`` if the tree is projective, ``False`` otherwise. + + Examples: + >>> CoNLL.isprojective([2, -1, 1]) # -1 denotes un-annotated cases + False + >>> CoNLL.isprojective([3, -1, 2]) + False + """ + + pairs = [(h, d) for d, h in enumerate(sequence, 1) if h >= 0] + for i, (hi, di) in enumerate(pairs): + for hj, dj in pairs[i + 1:]: + (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj]) + if li <= hj <= ri and hi == dj: + return False + if lj <= hi <= rj and hj == di: + return False + if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0: + return False + return True + + @classmethod + def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False) -> bool: + r""" + Checks if the arcs form an valid dependency tree. + + Args: + sequence (List[int]): + A list of head indices. + proj (bool): + If ``True``, requires the tree to be projective. Default: ``False``. + multiroot (bool): + If ``False``, requires the tree to contain only a single root. Default: ``True``. + + Returns: + ``True`` if the arcs form an valid tree, ``False`` otherwise. + + Examples: + >>> CoNLL.istree([3, 0, 0, 3], multiroot=True) + True + >>> CoNLL.istree([3, 0, 0, 3], proj=True) + False + """ + + from supar.structs.fn import tarjan + if proj and not cls.isprojective(sequence): + return False + n_roots = sum(head == 0 for head in sequence) + if n_roots == 0: + return False + if not multiroot and n_roots > 1: + return False + if any(i == head for i, head in enumerate(sequence, 1)): + return False + return next(tarjan(sequence), None) is None + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + proj: bool = False, + **kwargs + ) -> Iterable[CoNLLSentence]: + r""" + Loads the data in CoNLL-X format. + Also supports for loading data from CoNLL-U file with comments and non-integer IDs. + + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + proj (bool): + If ``True``, discards all non-projective sentences. Default: ``False``. + + Returns: + A list of :class:`CoNLLSentence` instances. + """ + + isconll = False + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + f = open(data) + if data.endswith('.txt'): + lines = (i + for s in f + if len(s) > 1 + for i in StringIO(self.toconll(s.split() if lang is None else tokenizer(s)) + '\n')) + else: + lines, isconll = f, True + else: + if lang is not None: + data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + lines = (i for s in data for i in StringIO(self.toconll(s) + '\n')) + + index, sentence = 0, [] + for line in lines: + line = line.strip() + if len(line) == 0: + sentence = CoNLLSentence(self, sentence, index) + if isconll and self.training and proj and not self.isprojective(list(map(int, sentence.arcs))): + logger.warning(f"Sentence {index} is not projective. Discarding it!") + else: + yield sentence + index += 1 + sentence = [] + else: + sentence.append(line) + + +class CoNLLSentence(Sentence): + r""" + Sencence in CoNLL-X format. + + Args: + transform (CoNLL): + A :class:`~supar.utils.transform.CoNLL` object. + lines (List[str]): + A list of strings composing a sentence in CoNLL-X format. + Comments and non-integer IDs are permitted. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + + Examples: + >>> lines = ['# text = But I found the location wonderful and the neighbors very kind.', + '1\tBut\t_\t_\t_\t_\t_\t_\t_\t_', + '2\tI\t_\t_\t_\t_\t_\t_\t_\t_', + '3\tfound\t_\t_\t_\t_\t_\t_\t_\t_', + '4\tthe\t_\t_\t_\t_\t_\t_\t_\t_', + '5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_', + '6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_', + '7\tand\t_\t_\t_\t_\t_\t_\t_\t_', + '7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_', + '8\tthe\t_\t_\t_\t_\t_\t_\t_\t_', + '9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_', + '10\tvery\t_\t_\t_\t_\t_\t_\t_\t_', + '11\tkind\t_\t_\t_\t_\t_\t_\t_\t_', + '12\t.\t_\t_\t_\t_\t_\t_\t_\t_'] + >>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb. + >>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3] + >>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp', + 'cc', 'det', 'dep', 'advmod', 'conj', 'punct'] + >>> sentence + # text = But I found the location wonderful and the neighbors very kind. + 1 But _ _ _ _ 3 cc _ _ + 2 I _ _ _ _ 3 nsubj _ _ + 3 found _ _ _ _ 0 root _ _ + 4 the _ _ _ _ 5 det _ _ + 5 location _ _ _ _ 6 nsubj _ _ + 6 wonderful _ _ _ _ 3 xcomp _ _ + 7 and _ _ _ _ 6 cc _ _ + 7.1 found _ _ _ _ _ _ _ _ + 8 the _ _ _ _ 9 det _ _ + 9 neighbors _ _ _ _ 11 dep _ _ + 10 very _ _ _ _ 11 advmod _ _ + 11 kind _ _ _ _ 6 conj _ _ + 12 . _ _ _ _ 3 punct _ _ + """ + + def __init__(self, transform: CoNLL, lines: List[str], index: Optional[int] = None) -> CoNLLSentence: + super().__init__(transform, index) + + self.values = [] + # record annotations for post-recovery + self.annotations = dict() + + for i, line in enumerate(lines): + value = line.split('\t') + if value[0].startswith('#') or not value[0].isdigit(): + self.annotations[-i - 1] = line + else: + self.annotations[len(self.values)] = line + self.values.append(value) + self.values = list(zip(*self.values)) + + def __repr__(self): + # cover the raw lines + merged = {**self.annotations, + **{i: '\t'.join(map(str, line)) + for i, line in enumerate(zip(*self.values))}} + return '\n'.join(merged.values()) + '\n' diff --git a/tania_scripts/supar/models/dep/crf/__init__.py b/tania_scripts/supar/models/dep/crf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27cae45e1d3c23070acfa28ffb75aedb99bf560e --- /dev/null +++ b/tania_scripts/supar/models/dep/crf/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import CRFDependencyModel +from .parser import CRFDependencyParser + +__all__ = ['CRFDependencyModel', 'CRFDependencyParser'] diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5810074dccba6a92cf481e86ef69ece8be74592c Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3609fddab845fa29cbbf98125bf4f14117dacd7c Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..350a941b0ad158a2c7ec7f74b109e4200050bd1e Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82c5f8a1f3181f461f6d3f7c6d863af8093a12ca Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0a41a163099e2b9eb112143902fdc323def85a0 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b93724d0f7d60bbc3af66307c55579f66b3ed2c5 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf/model.py b/tania_scripts/supar/models/dep/crf/model.py new file mode 100644 index 0000000000000000000000000000000000000000..472bd359efe2aa3b0670e0e8ef964d86c3b22822 --- /dev/null +++ b/tania_scripts/supar/models/dep/crf/model.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +import torch +from supar.models.dep.biaffine.model import BiaffineDependencyModel +from supar.structs import DependencyCRF, MatrixTree + + +class CRFDependencyModel(BiaffineDependencyModel): + r""" + The implementation of first-order CRF Dependency Parser + :cite:`zhang-etal-2020-efficient,ma-hovy-2017-neural,koo-etal-2007-structured`). + + Args: + n_words (int): + The size of the word vocabulary. + n_rels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_arc_mlp (int): + Arc MLP size. Default: 500. + n_rel_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + scale (float): + Scaling factor for affine scores. Default: 0. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + proj (bool): + If ``True``, takes :class:`DependencyCRF` as inference layer, :class:`MatrixTree` otherwise. + Default: ``True``. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def loss(self, s_arc, s_rel, arcs, rels, mask, mbr=True, partial=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + arcs (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard arcs. + rels (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + mbr (bool): + If ``True``, returns marginals for MBR decoding. Default: ``True``. + partial (bool): + ``True`` denotes the trees are partially annotated. Default: ``False``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The training loss and + original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise. + """ + + CRF = DependencyCRF if self.args.proj else MatrixTree + arc_dist = CRF(s_arc, mask.sum(-1)) + arc_loss = -arc_dist.log_prob(arcs, partial=partial).sum() / mask.sum() + arc_probs = arc_dist.marginals if mbr else s_arc + # -1 denotes un-annotated arcs + if partial: + mask = mask & arcs.ge(0) + s_rel, rels = s_rel[mask], rels[mask] + s_rel = s_rel[torch.arange(len(rels)), arcs[mask]] + rel_loss = self.criterion(s_rel, rels) + loss = arc_loss + rel_loss + return loss, arc_probs diff --git a/tania_scripts/supar/models/dep/crf/parser.py b/tania_scripts/supar/models/dep/crf/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed16896b85dac689a14c92d5b89d6732fc85640 --- /dev/null +++ b/tania_scripts/supar/models/dep/crf/parser.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- + +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.parser import BiaffineDependencyParser +from supar.models.dep.crf.model import CRFDependencyModel +from supar.structs import DependencyCRF, MatrixTree +from supar.utils import Config +from supar.utils.fn import ispunct +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class CRFDependencyParser(BiaffineDependencyParser): + r""" + The implementation of first-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`. + """ + + NAME = 'crf-dependency' + MODEL = CRFDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + mbr: bool = True, + tree: bool = False, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + mbr: bool = True, + tree: bool = True, + proj: bool = True, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + mbr: bool = True, + tree: bool = True, + proj: bool = True, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + if self.args.partial: + mask &= arcs.ge(0) + # ignore all punctuation if not specified + if not self.args.punct: + mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words])) + return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + CRF = DependencyCRF if self.args.proj else MatrixTree + words, _, *feats = batch + mask, lens = batch.mask, batch.lens - 1 + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_rel = self.model(words, feats) + s_arc = CRF(s_arc, lens).marginals if self.args.mbr else s_arc + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + lens = lens.tolist() + batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)] + batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + if self.args.prob: + arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1) + batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())] + return batch diff --git a/tania_scripts/supar/models/dep/crf2o/__init__.py b/tania_scripts/supar/models/dep/crf2o/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2acf9ce4dcc45c96c270384714470b59e94e927 --- /dev/null +++ b/tania_scripts/supar/models/dep/crf2o/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import CRF2oDependencyModel +from .parser import CRF2oDependencyParser + +__all__ = ['CRF2oDependencyModel', 'CRF2oDependencyParser'] diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb6754d2364431258038d51a7d60f97240ef9d09 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf033823fb198f1abbc2c5fe24ab7427a6e3819 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6243392a5a1c5d884a93eb8cd18fbcf1edad6ec4 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7586c6f2981ac0dc9e8dbe7252d6f9ed741d5805 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf299268ee7bd3bd211e8bd3d917638bd104180 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec22eb9df296b742697cdd799d431ff90162da41 Binary files /dev/null and b/tania_scripts/supar/models/dep/crf2o/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/crf2o/model.py b/tania_scripts/supar/models/dep/crf2o/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1b53bdd4f7db62faf30350e50877f7d61b4a6cc2 --- /dev/null +++ b/tania_scripts/supar/models/dep/crf2o/model.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.models.dep.biaffine.model import BiaffineDependencyModel +from supar.models.dep.biaffine.transform import CoNLL +from supar.modules import MLP, Biaffine, Triaffine +from supar.structs import Dependency2oCRF, MatrixTree +from supar.utils import Config +from supar.utils.common import MIN + + +class CRF2oDependencyModel(BiaffineDependencyModel): + r""" + The implementation of second-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`. + + Args: + n_words (int): + The size of the word vocabulary. + n_rels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_arc_mlp (int): + Arc MLP size. Default: 500. + n_sib_mlp (int): + Sibling MLP size. Default: 100. + n_rel_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + scale (float): + Scaling factor for affine scores. Default: 0. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_rels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_sib_mlp=100, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.sib_mlp_s = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.sib_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.sib_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + + self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False) + self.sib_attn = Triaffine(n_in=n_sib_mlp, scale=scale, bias_x=True, bias_y=True) + self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor, ~torch.Tensor: + Scores of all possible arcs (``[batch_size, seq_len, seq_len]``), + dependent-head-sibling triples (``[batch_size, seq_len, seq_len, seq_len]``) and + all possible labels on each arc (``[batch_size, seq_len, seq_len, n_labels]``). + """ + + x = self.encode(words, feats) + mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1) + + arc_d = self.arc_mlp_d(x) + arc_h = self.arc_mlp_h(x) + sib_s = self.sib_mlp_s(x) + sib_d = self.sib_mlp_d(x) + sib_h = self.sib_mlp_h(x) + rel_d = self.rel_mlp_d(x) + rel_h = self.rel_mlp_h(x) + + # [batch_size, seq_len, seq_len] + s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN) + # [batch_size, seq_len, seq_len, seq_len] + s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2) + # [batch_size, seq_len, seq_len, n_rels] + s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1) + + return s_arc, s_sib, s_rel + + def loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask, mbr=True, partial=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-sibling triples. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + arcs (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard arcs. + sibs (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard siblings. + rels (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + mbr (bool): + If ``True``, returns marginals for MBR decoding. Default: ``True``. + partial (bool): + ``True`` denotes the trees are partially annotated. Default: ``False``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The training loss and + original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise. + """ + + arc_dist = Dependency2oCRF((s_arc, s_sib), mask.sum(-1)) + arc_loss = -arc_dist.log_prob((arcs, sibs), partial=partial).sum() / mask.sum() + if mbr: + s_arc, s_sib = arc_dist.marginals + # -1 denotes un-annotated arcs + if partial: + mask = mask & arcs.ge(0) + s_rel, rels = s_rel[mask], rels[mask] + s_rel = s_rel[torch.arange(len(rels)), arcs[mask]] + rel_loss = self.criterion(s_rel, rels) + loss = arc_loss + rel_loss + return loss, s_arc, s_sib + + def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-sibling triples. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + mbr (bool): + If ``True``, performs MBR decoding. Default: ``True``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + + Returns: + ~torch.LongTensor, ~torch.LongTensor: + Predicted arcs and labels of shape ``[batch_size, seq_len]``. + """ + + lens = mask.sum(1) + arc_preds = s_arc.argmax(-1) + bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())] + if tree and any(bad): + if proj: + arc_preds[bad] = Dependency2oCRF((s_arc[bad], s_sib[bad]), mask[bad].sum(-1)).argmax + else: + arc_preds[bad] = MatrixTree(s_arc[bad], mask[bad].sum(-1)).argmax + rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1) + + return arc_preds, rel_preds diff --git a/tania_scripts/supar/models/dep/crf2o/parser.py b/tania_scripts/supar/models/dep/crf2o/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c822b2a154aa326c743fae85b157a476c9fdffcb --- /dev/null +++ b/tania_scripts/supar/models/dep/crf2o/parser.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.parser import BiaffineDependencyParser +from supar.models.dep.biaffine.transform import CoNLL +from supar.models.dep.crf2o.model import CRF2oDependencyModel +from supar.structs import Dependency2oCRF +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, PAD, UNK +from supar.utils.field import ChartField, Field, RawField, SubwordField +from supar.utils.fn import ispunct +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class CRF2oDependencyParser(BiaffineDependencyParser): + r""" + The implementation of second-order CRF Dependency Parser :cite:`zhang-etal-2020-efficient`. + """ + + NAME = 'crf2o-dependency' + MODEL = CRF2oDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + mbr: bool = True, + tree: bool = False, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + mbr: bool = True, + tree: bool = True, + proj: bool = True, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + mbr: bool = True, + tree: bool = True, + proj: bool = True, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, _, *feats, arcs, sibs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + loss, *_ = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + words, _, *feats, arcs, sibs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + loss, s_arc, s_sib = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial) + arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj) + if self.args.partial: + mask &= arcs.ge(0) + # ignore all punctuation if not specified + if not self.args.punct: + mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words])) + return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, _, *feats = batch + mask, lens = batch.mask, batch.lens - 1 + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + s_arc, s_sib = Dependency2oCRF((s_arc, s_sib), lens).marginals if self.args.mbr else (s_arc, s_sib) + arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj) + lens = lens.tolist() + batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)] + batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + if self.args.prob: + arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1) + batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())] + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default: 2. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + TAG, CHAR, ELMO, BERT = None, None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True) + if 'tag' in args.feat: + TAG = Field('tags', bos=BOS) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len) + if 'elmo' in args.feat: + from allennlp.modules.elmo import batch_to_ids + ELMO = RawField('elmo') + ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device) + if 'bert' in args.feat: + t = TransformerTokenizer(args.bert) + BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + BERT.vocab = t.vocab + TEXT = RawField('texts') + ARC = Field('arcs', bos=BOS, use_vocab=False, fn=CoNLL.get_arcs) + SIB = ChartField('sibs', bos=BOS, use_vocab=False, fn=CoNLL.get_sibs) + REL = Field('rels', bos=BOS) + transform = CoNLL(FORM=(WORD, TEXT, CHAR, ELMO, BERT), CPOS=TAG, HEAD=(ARC, SIB), DEPREL=REL) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if TAG is not None: + TAG.build(train) + if CHAR is not None: + CHAR.build(train) + REL.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_rels': len(REL.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'bert_pad_index': BERT.pad_index if BERT is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/dep/eager/__init__.py b/tania_scripts/supar/models/dep/eager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e985987237e25784b579e74d88c25f8a1b3de7 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/__init__.py @@ -0,0 +1,2 @@ +from .parser import ArcEagerDependencyParser +from .model import ArcEagerDependencyModel \ No newline at end of file diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69b6b347907cd94bda81a4738bf8e7e85860c421 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4eb2047d9191fa7725d4f6c991fbd42ac25afea Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35f8752a167848718305af293b7ba376172b991 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd043bef1f8af3f5ca69f649ed48a44101524e37 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f44d08db4b5106cd902edd53576ffd54fc22572 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c53842f2152b6d56e2c6e8626ae48e37586c53a4 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d39d5378eae49dfe2f72c59831aec35703b2ca Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbdee27774e2be073b02a1de3f1e5540e6c700cf Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/model.py b/tania_scripts/supar/models/dep/eager/model.py new file mode 100644 index 0000000000000000000000000000000000000000..300e82969f98af688a3f1700c7aab969ff2b3d75 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/model.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, DecoderLSTM +from supar.utils import Config +from typing import Tuple, List + +class ArcEagerDependencyModel(Model): + + def __init__(self, + n_words, + n_transitions, + n_trels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + pad_index=0, + unk_index=1, + n_decoder_layers=4, + **kwargs): + super().__init__(**Config().update(locals())) + + # create decoder for buffer front, stack top and rels + self.transition_decoder = self.rel_decoder = None, None + + stack_size, buffer_size = [self.args.n_encoder_hidden//2] * 2 if (self.args.n_encoder_hidden % 2) == 0 \ + else [self.args.n_encoder_hidden//2, self.args.n_encoder_hidden//2+1] + + # create projection to reduce dimensionality of the encoder + self.stack_proj = MLP( + n_in=self.args.n_encoder_hidden, n_out=stack_size, + dropout=mlp_dropout) + self.buffer_proj = MLP( + n_in=self.args.n_encoder_hidden, n_out=buffer_size, + dropout=mlp_dropout + ) + + if self.args.decoder == 'lstm': + decoder = lambda out_dim: DecoderLSTM( + self.args.n_encoder_hidden, self.args.n_encoder_hidden, out_dim, + self.args.n_decoder_layers, dropout=mlp_dropout, device=self.device + ) + else: + decoder = lambda out_dim: MLP( + n_in=self.args.n_encoder_hidden, n_out=out_dim, dropout=mlp_dropout + ) + + self.transition_decoder = decoder(n_transitions) + self.trel_decoder = decoder(n_trels) + + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1), + n_out=self.args.n_encoder_hidden, dropout=mlp_dropout) + + # create PoS tagger + if self.args.encoder in ['lstm', 'bert']: + self.pos_tagger = DecoderLSTM( + self.args.n_encoder_hidden, self.args.n_encoder_hidden, self.args.n_tags, + num_layers=1, dropout=mlp_dropout, device=self.device + ) + else: + #args.encoder is bert + self.pos_tagger = nn.Identity() + + self.criterion = nn.CrossEntropyLoss() + + def encoder_forward(self, words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Applies encoding forward pass. Maps a tensor of word indices (`words`) to their corresponding neural + representation. + Args: + words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay] + feats: List[torch.Tensor] + lens: List[int] + + Returns: x, qloss + x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + qloss: torch.FloatTensor ~ 1 + + """ + + x = super().encode(words, feats) + s_tag = self.pos_tagger(x[:, 1:-(1+self.args.delay), :]) + + # adjust lengths to allow delay predictions + # x ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay + 1)], dim=2) + x = self.delay_proj(x) + + # pass through vector quantization + x, qloss = self.vq_forward(x) + return x, s_tag, qloss + + def decoder_forward(self, x: torch.Tensor, stack_top: torch.Tensor, buffer_front: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Args: + x: torch.FloatTensor ~ [batch_size, bos + pad(seq_len) + eos, embed_dim] + stack_top: torch.IntTensor ~ [batch_size, pad(tr_len)] + buffer_front: torch.IntTensor ~ [batch_size, pad(tr_len)] + + Returns: s_transition, s_trel + s_transition: torch.FloatTensor ~ [batch_size, pad(tr_len), n_transitions] + s_trel: torch.FloatTensor ~ [batch_size, pad(tr_len), n_trels] + """ + batch_size = x.shape[0] + + # obtain encoded embeddings for stack_top and buffer_front + stack_top = torch.stack([x[i, stack_top[i], :] for i in range(batch_size)]) + buffer_front = torch.stack([x[i, buffer_front[i], :] for i in range(batch_size)]) + + # pass through projections + stack_top = self.stack_proj(stack_top) + buffer_front = self.buffer_proj(buffer_front) + + # stack_top ~ [batch_size, pad(tr_len), embed_dim//2] + # buffer_front ~ [batch_size, pad(tr_len), embed_dim//2] + # x ~ [batch_size, pad(tr_len), embed_dim] + x = torch.concat([stack_top, buffer_front], dim=-1) + + # s_transition ~ [batch_size, pad(tr_len), n_transitions] + # s_trel = [batch_size, pad(tr_len), n_trels] + s_transition = self.transition_decoder(x) + s_trel = self.trel_decoder(x) + + return s_transition, s_trel + + def forward(self, words: torch.Tensor, stack_top: torch.Tensor, buffer_front: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Args: + words: torch.IntTensor ~ [batch_size, bos + pad(seq_len) + eos + delay]. + stack_top: torch.IntTensor ~ [batch_size, pad(tr_len)] + buffer_front: torch.IntTensor ~ [batch_size, pad(tr_len)] + feats: List[torch.Tensor] + + Returns: s_transition, s_trel, qloss + s_transition: torch.FloatTensor ~ [batch_size, pad(tr_len), n_transitions] + s_trel: torch.FloatTensor ~ [batch_size, pad(tr_len), n_trels] + qloss: torch.FloatTensor ~ 1 + """ + x, s_tag, qloss = self.encoder_forward(words, feats) + + s_transition, s_trel = self.decoder_forward(x, stack_top, buffer_front) + return s_transition, s_trel, s_tag, qloss + + def decode(self, s_transition: torch.Tensor, s_trel: torch.Tensor, exclude: list = None): + transition_preds = s_transition.argsort(-1, descending=True) + if exclude: + s_trel[:, :, exclude] = -1 + trel_preds = s_trel.argmax(-1) + return transition_preds, trel_preds + + def decode_stag(self, s_tag: torch.Tensor): + stag_preds = s_tag.argmax(-1) + #stag_preds = s_tag.argsort(-1, descending=True) + return stag_preds + + def loss(self, s_transition: torch.Tensor, s_trel: torch.Tensor, s_tag, + transitions: torch.Tensor, trels: torch.Tensor, tags, + smask: torch.Tensor, trmask: torch.Tensor, TRANSITION): + s_transition, transitions = s_transition[trmask], transitions[trmask] + s_trel, trels = s_trel[trmask], trels[trmask] + + # remove those values in trels that correspond to shift and reduce actions + transition_pred = TRANSITION.vocab[s_transition.argmax(-1).flatten().tolist()] + trel_mask = torch.tensor(list(map(lambda x: x not in ['reduce', 'shift'], transition_pred))) + s_trel, trels = s_trel[trel_mask], trels[trel_mask] + + tag_loss = self.criterion(s_tag[smask], tags[smask]) if self.args.encoder == 'lstm' else torch.tensor(0).to(self.device) + transition_loss = self.criterion(s_transition, transitions) + trel_loss = self.criterion(s_trel, trels) + + return transition_loss + trel_loss + tag_loss diff --git a/tania_scripts/supar/models/dep/eager/oracle/__init__.py b/tania_scripts/supar/models/dep/eager/oracle/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cf072a90eda87b6f954a0a161965c253cafd1b1 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13ce2553293d74a70524325df69e92ffe8f7771d Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93488668e82fa23ba9e3ff1768cb420f90132ce3 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09ee22f8790adc5bbc67d6acfbc51c23685850e4 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/arceager.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41d88bf55396dcd614cb06abfdcd6cdd80a64562 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c22b12dc825c9a45c4c9b0f8e44b31477268e653 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/buffer.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fce37dccace89b6d7436c953040e05640fbe754 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9534220075a54bc9b129b311b945ea708d02dc8 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/dependency.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d76682cca61dd66a9aa936ace916052a3ae2e0f7 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a8d46ef4a1c5317321c19f8323f13e6319c4fe0 Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/node.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f963a4cfe22b0dc6246066d745d20c9344ac44ed Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a362f4b907de1716061719405def54b641c7abd Binary files /dev/null and b/tania_scripts/supar/models/dep/eager/oracle/__pycache__/stack.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/eager/oracle/arceager.py b/tania_scripts/supar/models/dep/eager/oracle/arceager.py new file mode 100644 index 0000000000000000000000000000000000000000..b65dad15199cb5d6338afdb3b84b6a18ace20918 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/oracle/arceager.py @@ -0,0 +1,219 @@ +from supar.models.dep.eager.oracle.stack import Stack +from supar.models.dep.eager.oracle.buffer import Buffer +from supar.models.dep.eager.oracle.node import Node +from supar.models.dep.eager.oracle.dependency import Transition +from typing import List, Dict + +class ArcEagerEncoder: + def __init__(self, bos: str, eos: str): + self.stack = None + self.buffer = None + self.transitions = [] + self.dependencies = [] + self.nodes_assigned = [] + self.bos, self.eos = bos, eos + self.shift_token = '<shift>' + self.reduce_token = '<reduce>' + self.n = 0 + + def left_arc(self): + # head = buffer_front, dependent = stack_top (stack_top <- buffer_front) + stack_top = self.stack.pop() + buffer_front = self.buffer.get() + self.transitions.append( + Transition(type='left-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=stack_top.DEPREL) + ) + self.dependencies.append( + Node(ID=stack_top.ID, FORM=stack_top.FORM, UPOS=stack_top.UPOS, HEAD=buffer_front.ID, DEPREL=stack_top.DEPREL) + ) + self.nodes_assigned.append(stack_top.ID) + return self.transitions + + def right_arc(self): + # head = stack_top, dependent = buffer_front (stack_top -> buffer_front) + stack_top = self.stack.get() + buffer_front = self.buffer.remove() + self.stack.push(buffer_front) + self.transitions.append( + Transition(type='right-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=buffer_front.DEPREL) + ) + self.dependencies.append( + Node(ID=buffer_front.ID, FORM=buffer_front.FORM, UPOS=buffer_front.UPOS, HEAD=stack_top.ID, DEPREL=buffer_front.DEPREL) + ) + self.nodes_assigned.append(buffer_front.ID) + return self.transitions + + def shift(self): + front_item = self.buffer.remove() + stack_top = self.stack.get() + self.stack.push(front_item) + self.transitions.append( + Transition(type='shift', stack_top=stack_top, buffer_front=front_item, deprel=front_item.DEPREL)) + return self.transitions + + def reduce(self): + stack_top = self.stack.get() + buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(self.n, self.eos) + self.stack.pop() + self.transitions.append( + Transition(type='reduce', stack_top=stack_top, buffer_front=buffer_front, deprel=stack_top.DEPREL) + ) + return self.transitions + + def next_action(self): + stack_top = self.stack.get() + try: + buffer_front = self.buffer.get() + except IndexError: + if stack_top.ID in self.nodes_assigned: + return self.reduce() + return None + + if buffer_front.ID == stack_top.HEAD: + return self.left_arc() + elif (buffer_front.ID not in self.nodes_assigned) and (stack_top.ID == buffer_front.HEAD): + return self.right_arc() + elif (stack_top.ID in self.nodes_assigned) and (buffer_front.HEAD in [node.ID for node in self.stack.items]): + return self.reduce() + elif (stack_top.ID in self.nodes_assigned) and (buffer_front.ID in [node.HEAD for node in self.stack.items]): + return self.reduce() + else: + return self.shift() + + def encode(self, sentence: List[Node]): + # create stack and buffer + self.stack = Stack([Node.create_root(self.bos)]) + self.buffer = Buffer(sentence.copy()) + self.n = len(sentence) + + # reset + self.transitions, self.dependencies = [], [] + + next_action = self.next_action() + while next_action: + next_action = self.next_action() + + # remove values + self.dependencies = sorted(self.dependencies, key=lambda dep: dep.ID) + return self.transitions + + + +class ArcEagerDecoder: + def __init__(self, sentence: List[Node], bos: str, eos: str, unk: str): + self.sentence = sentence.copy() + self.decoded_nodes = [Node(ID=node.ID, FORM=node.FORM, UPOS=node.UPOS, HEAD=0, DEPREL=unk) for node in sentence] + self.transitions = list() + self.nodes_assigned = list() + self.stack = Stack([Node.create_root(bos)]) + self.buffer = Buffer(sentence.copy()) + self.bos, self.eos, self.unk = bos, eos, unk + self.shift_token, self.reduce_token = '<shift>', '<reduce>' + + + def left_arc(self, deprel: str): + # head = buffer_front, dependent = stack_top (stack_top <- buffer_front) + stack_top = self.stack.pop() + buffer_front = self.buffer.get() + self.nodes_assigned.append(stack_top.ID) + self.transitions.append( + Transition(type='left-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel) + ) + self.decoded_nodes[stack_top.ID - 1].HEAD = buffer_front.ID + self.decoded_nodes[stack_top.ID - 1].DEPREL = deprel + # get next states + stack_top = self.stack.get() + buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos) + return stack_top, buffer_front + + def right_arc(self, deprel: str): + # head = stack_top, dependent = buffer_front (stack_top -> buffer_front) + stack_top = self.stack.get() + buffer_front = self.buffer.remove() + self.stack.push(buffer_front) + self.transitions.append( + Transition(type='right-arc', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel) + ) + self.nodes_assigned.append(buffer_front.ID) + self.decoded_nodes[buffer_front.ID - 1].HEAD = stack_top.ID + self.decoded_nodes[buffer_front.ID - 1].DEPREL = deprel + + # get next states + stack_top = self.stack.get() + buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos) + return stack_top, buffer_front + + + def shift(self, deprel): + front_item = self.buffer.remove() + stack_top = self.stack.get() + self.stack.push(front_item) + self.transitions.append( + Transition(type='shift', stack_top=stack_top, buffer_front=front_item, deprel=deprel)) + # get next states + stack_top = self.stack.get() + buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos) + return stack_top, buffer_front + + def reduce(self, deprel): + stack_top = self.stack.get() + try: + buffer_front = self.buffer.get() + except IndexError: + buffer_front = Node.create_eos(len(self.sentence), self.eos) + self.stack.pop() + self.transitions.append( + Transition(type='reduce', stack_top=stack_top, buffer_front=buffer_front, deprel=deprel) + ) + # get next states + stack_top = self.stack.get() + buffer_front = self.buffer.get() if len(self.buffer) > 0 else Node.create_eos(len(self.sentence), self.eos) + return stack_top, buffer_front + + def apply_transition(self, transitions: List[str], deprel: str): + stack_top = self.stack.get() + try: + buffer_front = self.buffer.get() + except IndexError: + if stack_top.ID in self.nodes_assigned: + self.reduce(deprel) + return None + + for transition in transitions: + if (transition == 'left-arc') and (stack_top.ID not in self.nodes_assigned) and (not stack_top.is_root): + return self.left_arc(deprel) + if (transition == 'right-arc') and (buffer_front.ID not in self.nodes_assigned): + return self.right_arc(deprel) + if (transition == 'reduce') and (stack_top.ID in self.nodes_assigned): + return self.reduce(deprel) + return self.shift(deprel) + + def apply(self, transition: str, deprel: str): + if transition == 'left-arc': + return self.left_arc(deprel) + if transition == 'right-arc': + return self.right_arc(deprel) + if transition == 'reduce': + return self.reduce(deprel) + if transition == 'shift': + return self.shift(deprel) + + def decode_sentence(self, transitions: List[List[str]], deprels: List[str]): + for transition_ops, deprel in zip(transitions, deprels): + info = self.apply_transition(transition_ops, deprel) + if info is None: + break + self.postprocess() + return self.decoded_nodes + + def postprocess(self): + # check if there are more than one node with root head + roots = sum([node.HEAD == 0 for node in self.decoded_nodes]) + if roots > 1: + # get leftmost root + for node in self.decoded_nodes: + if node.HEAD == 0: + root = node.ID + for node in self.decoded_nodes[root:]: + if node.HEAD == 0: + node.HEAD = root diff --git a/tania_scripts/supar/models/dep/eager/oracle/buffer.py b/tania_scripts/supar/models/dep/eager/oracle/buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1f5d436d970327abef248843db0fb4890cbeeb --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/oracle/buffer.py @@ -0,0 +1,24 @@ +from typing import Optional, List +from supar.models.dep.eager.oracle.node import Node + +class Buffer: + def __init__(self, items: Optional[List[Node]]): + if items: + self.items = items + else: + self.items = [] + + def get(self) -> Node: + return self.items[0] + + def remove(self) -> Node: + return self.items.pop(0) + + def append(self, item: Node): + self.items.append(item) + + def __len__(self): + return len(self.items) + + def __repr__(self): + return str(self.items) \ No newline at end of file diff --git a/tania_scripts/supar/models/dep/eager/oracle/dependency.py b/tania_scripts/supar/models/dep/eager/oracle/dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..a4825ba75f62256dd3038c43d96578d4fd6c99c3 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/oracle/dependency.py @@ -0,0 +1,30 @@ +from typing import Optional, List +from supar.models.dep.eager.oracle.node import Node + +class Dependency: + def __init__(self, dependent_id: int, head_id: str, deprel: str): + self.dependent_id = dependent_id + self.head_id = head_id + self.deprel = deprel + + def __repr__(self): + return self.toconll() + + def toconll(self): + return '\t'.join([ + str(self.dependent_id), str(self.head_id), self.deprel + ]) + +class Transition: + def __init__(self, type: str, stack_top: Node, buffer_front: Node, deprel: str): + self.type = type + self.stack_top = stack_top + self.buffer_front = buffer_front + self.deprel = deprel + + def __repr__(self): + return '\t'.join((str(self.stack_top.FORM), str(self.buffer_front.FORM), self.type, self.deprel)) + + + + diff --git a/tania_scripts/supar/models/dep/eager/oracle/node.py b/tania_scripts/supar/models/dep/eager/oracle/node.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa07a99a616853235e6b1d2dc18f6709319e2a6 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/oracle/node.py @@ -0,0 +1,74 @@ +from typing import List + +class Node: + + def __init__(self, ID: int, FORM: str, UPOS: str, HEAD: int, DEPREL: str, is_root: bool = False): + self.ID = ID + self.FORM = FORM + self.UPOS = UPOS + self.HEAD = HEAD + self.DEPREL = DEPREL + self.is_root = is_root + + def __str__(self): + return f'Node(ID={self.ID}, FORM={self.FORM}, UPOS={self.UPOS}, HEAD={self.HEAD})' + + def __repr__(self): + return f'Node(ID={self.ID}, FORM={self.FORM}, UPOS={self.UPOS}, HEAD={self.HEAD})' + + @classmethod + def from_conllu(cls, conll: str): + ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = conll.split('\t') + if HEAD == '_': + return Node(int(ID), FORM, UPOS, HEAD, DEPREL) + else: + return Node(int(ID), FORM, UPOS, int(HEAD), DEPREL) + + @classmethod + def create_root(cls, token: str): + return Node(0, token, token, 0, token, is_root=True) + + @classmethod + def create_eos(cls, position: int, token: str): + return Node(position, token, token, 0, token, is_root=False) + + def coverage(self) -> range: + limits = sorted([self.ID, self.HEAD]) + return range(*limits) + + def isprojective(heads: List[int]): + pairs = [(h, d) for d, h in enumerate(heads, 1) if h >= 0] + for i, (hi, di) in enumerate(pairs): + for hj, dj in pairs[i + 1:]: + (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj]) + if li <= hj <= ri and hi == dj: + print('1') + return False + if lj <= hi <= rj and hj == di: + print('2') + return False + if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0: + print('3') + print(di, hi, dj, hj) + return False + return True + +def isprojective(heads: List[int]): + pairs = [(h, d) for d, h in enumerate(heads, 1) if h >= 0] + for i, (hi, di) in enumerate(pairs): + for hj, dj in pairs[i + 1:]: + (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj]) + if li <= hj <= ri and hi == dj: + print('1') + return False + if lj <= hi <= rj and hj == di: + print('2') + return False + if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0: + print('3') + print(di, hi, dj, hj) + return False + return True + + + diff --git a/tania_scripts/supar/models/dep/eager/oracle/stack.py b/tania_scripts/supar/models/dep/eager/oracle/stack.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5545a54e71d347c514d385848a08889cade629 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/oracle/stack.py @@ -0,0 +1,24 @@ +from typing import Union, List, Optional +from supar.models.dep.eager.oracle.node import Node + +class Stack: + def __init__(self, items: Optional[List[Node]] = None): + if items: + self.items = items + else: + self.items = [] + + def pop(self) -> Node: + return self.items.pop(-1) + + def push(self, item: Node): + self.items.append(item) + + def get(self) -> Node: + return self.items[-1] + + def __repr__(self): + return repr(self.items) + + + diff --git a/tania_scripts/supar/models/dep/eager/parser.py b/tania_scripts/supar/models/dep/eager/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f4857ded7c9bc3b61042ef60dbd33e9f4800972f --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/parser.py @@ -0,0 +1,380 @@ +import os +from supar.models.dep.eager.oracle.node import Node + +import torch +from supar.models.dep.eager.model import ArcEagerDependencyModel +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, PAD, UNK, EOS +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.fn import ispunct +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch +from supar.models.dep.eager.transform import ArcEagerTransform +from supar.models.dep.eager.oracle.arceager import ArcEagerDecoder +from itertools import groupby + + +logger = get_logger(__name__) +from typing import Tuple, List, Union + +def consecutive_duplicate_spans(L): + new_list = [] + for group in groupby(L): + f = list(group[1]) + if len(f) > 1: + new_el = f[0].replace('-arc', '') + "*" + new_list.append(new_el) + elif len(f) == 1: + new_el = f[0].replace('-arc', '') + new_list.append(new_el) + return new_list + + +class ArcEagerDependencyParser(Parser): + MODEL = ArcEagerDependencyModel + NAME = 'arceager-dependency' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.FORM = self.transform.FORM + self.STACK_TOP = self.transform.STACK_TOP + self.BUFFER_FRONT = self.transform.BUFFER_FRONT + self.TRANSITION, self.TREL = self.transform.TRANSITION, self.transform.TREL + self.TAG = self.transform.UPOS + self.HEAD = self.transform.HEAD + + def train_step(self, batch: Batch) -> torch.Tensor: + words, texts, *feats, tags, _, _, stack_top, buffer_front, transitions, trels = batch + #print('dep eager parser feats', len(*feats)) + tmask = self.get_padding_mask(stack_top) + + # pad stack_top and buffer_front vectors: note that padding index must be the length of the sequence + pad_indices = (batch.lens - 1 - self.args.delay).tolist() + stack_top, buffer_front = self.pad_tensor(stack_top, pad_indices), self.pad_tensor(buffer_front, pad_indices) + # forward pass + # stack_top: torch.Tensor ~ [batch_size, pad(tr_len), n_transitions] + # buffer_front: torch.Tensor ~ [batch_size, pad(tr_len), n_trels] + s_transition, s_trel, s_tag, qloss = self.model(words, stack_top, buffer_front, feats) + + # compute loss + smask = batch.mask[:, (2+self.args.delay):] + loss = self.model.loss(s_transition, s_trel, s_tag, transitions, trels, tags, smask, tmask, self.TRANSITION) + qloss + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + words, texts, *feats, tags, heads, deprels, stack_top, buffer_front, transitions, trels = batch + transition_mask = self.get_padding_mask(stack_top) + + + # obtain transition loss + stack_top, buffer_front = \ + self.pad_tensor(stack_top, (batch.lens - 1 - self.args.delay).tolist()), \ + self.pad_tensor(buffer_front, (batch.lens - 1 - self.args.delay).tolist()) + s_transition, s_trel, s_tag, qloss = self.model(words, stack_top, buffer_front, feats.copy()) + smask = batch.mask[:, (2+self.args.delay):] + loss = self.model.loss(s_transition, s_trel, s_tag, transitions, trels, tags, smask, transition_mask, self.TRANSITION) + qloss + + # obtain indices of deprels from TREL field + batch_size = words.shape[0] + deprels = [self.TREL.vocab[deprels[b]] for b in range(batch_size)] + + # create decoders + lens = list(map(len, texts)) + sentences = list() + for b in range(batch_size): + sentences.append( + [ + Node(ID=i + 1, FORM=form, UPOS='', HEAD=head, DEPREL=deprel) for i, (form, head, deprel) in \ + enumerate(zip(texts[b], heads[b, :lens[b]].tolist(), deprels[b])) + ] + ) + decoders = list(map( + lambda sentence: ArcEagerDecoder(sentence=sentence, bos='', eos='', unk=self.transform.TREL.unk_index), + sentences + )) + + # compute oracle simulation for all elements in batch + head_preds, deprel_preds, *_ = self.oracle_decoding(decoders, words, feats) + head_preds, deprel_preds = self.pad_tensor(head_preds), self.pad_tensor(deprel_preds) + deprels = self.pad_tensor(deprels) + + seq_mask = batch.mask[:, (2 + self.args.delay):] + return AttachmentMetric(loss, (head_preds, deprel_preds), (heads, deprels), seq_mask) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, texts, *feats = batch + #print('SENTENCE: ', ' '.join(list(texts[0]))) + #print() + lens = (batch.lens - 2 - self.args.delay).tolist() + + + batch_size = words.shape[0] + # create decoders + sentences = list() + for b in range(batch_size): + sentences.append( + [ + Node(ID=i + 1, FORM='', UPOS='', HEAD=None, DEPREL=None) for i in range(lens[b]) + ] + ) + decoders = list(map( + lambda sentence: ArcEagerDecoder(sentence=sentence, bos='', eos='', unk=self.transform.TREL.unk_index), + sentences + )) + # compute oracle simulation for all elements in batch + head_preds, deprel_preds, stack_list, buffer_list,actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, act_dict = self.oracle_decoding(decoders, words, feats) + batch.heads = head_preds + batch.rels = deprel_preds + return batch, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, list(texts[0]), act_dict + + def get_padding_mask(self, tensor_list: List[torch.Tensor]) -> torch.Tensor: + """ + From a list of tensors of different lengths, creates a padding mask where False values indicates + padding tokens. True otherwise. + Args: + tensor_list: List of tensors. + Returns: torch.Tensor ~ [len(tensor_list), max(lenghts)] + + """ + lens = list(map(len, tensor_list)) + max_len = max(lens) + return torch.tensor([[True] * length + [False] * (max_len - length) for length in lens]).to(self.model.device) + + def pad_tensor( + self, + tensor_list: Union[List[torch.Tensor], List[List[int]]], + pad_index: Union[int, List[int]] = 0 + ): + """ + Applies padding to a list of tensors or list of lists. + Args: + tensor_list: List of tensors or list of lists. + pad_index: Index used for padding or list of indices used for padding for each item of tensor_list. + Returns: torch.Tensor ~ [len(tensor_list), max(lengths)] + """ + max_length = max(map(len, tensor_list)) + if isinstance(pad_index, int): + if isinstance(tensor_list[0], list): + return torch.tensor( + [tensor + [pad_index] * (max_length - len(tensor)) for tensor in tensor_list]).to(self.model.device) + else: + return torch.tensor( + [tensor.tolist() + [pad_index] * (max_length - len(tensor)) for tensor in tensor_list]).to(self.model.device) + else: + pad_indexes = pad_index + if isinstance(tensor_list[0], list): + return torch.tensor( + [tensor + [pad_index] * (max_length - len(tensor)) for tensor, pad_index in + zip(tensor_list, pad_indexes)]).to(self.model.device) + else: + return torch.tensor( + [tensor.tolist() + [pad_index] * (max_length - len(tensor)) for tensor, pad_index in + zip(tensor_list, pad_indexes)]).to(self.model.device) + + def get_text_mask(self, batch): + text_lens = (batch.lens - 2 - self.args.delay).tolist() + mask = batch.mask + mask[:, 0] = 0 # remove bos token + for i, text_len in enumerate(text_lens): + mask[i, (1 + text_len):] = 0 + return mask + + def oracle_decoding(self, decoders: List[ArcEagerDecoder], words: torch.Tensor, feats: List[torch.Tensor]) -> Tuple[ + List[List[int]]]: + """ + Implements Arc-Eager decoding. Using words indices, creates the initial state of the Arc-Eager oracle + and predicts each (transition, trel) with the TransitionDependencyModel. + Args: + decoders: List[ArcEagerDecoder] ~ batch_size + words: torch.Tensor ~ [batch_size, seq_len] + feats: List[torch.Tensor ~ [batch_size, seq_len, feat_embed]] ~ n_feat + + + Returns: head_preds, deprel_preds + head_preds: List[List[int] ~ sen_len] ~ batch_size: Head values for each sentence in batch. + deprel_preds: List[List[int] ~ sen_len] ~ batch_size: Indices of dependency relations for each sentence in batch. + """ + # create a mask vector to filter those decoders that achieved the final state + compute = [True for _ in range(len(decoders))] + batch_size = len(decoders) + #print('ORACLE DECO TAG VOCAB', self.TAG.vocab.items()) + + exclude = self.TREL.vocab[['<reduce>', '<shift>']] + + # obtain word representations of the encoder + x, s_tag, _ = self.model.encoder_forward(words, feats) + transition_stags = self.model.decode_stag(s_tag) + + + stack_top = [torch.tensor([decoders[b].stack.get().ID]).reshape(1) for b in range(batch_size)] + buffer_front = [torch.tensor([decoders[b].buffer.get().ID]).reshape(1) for b in range(batch_size)] + counter = 0 + + stack_list = [] + buffer_list = [] + actions_list = [] + + #print('stack_top', stack_top[-1].item()) + #print('buffer_front', buffer_front[-1].item()) + stack_list.append(stack_top[-1].item()) + buffer_list.append(buffer_front[-1].item()) + + while any(compute): + s_transition, s_trel = self.model.decoder_forward(x, torch.stack(stack_top), torch.stack(buffer_front)) + transition_preds, trel_preds = self.model.decode(s_transition, s_trel, exclude) + transition_preds, trel_preds = transition_preds[:, counter, :], trel_preds[:, counter] + transition_preds = [self.TRANSITION.vocab[i.tolist()] for i in + transition_preds.reshape(batch_size, self.args.n_transitions)] + for b, decoder in enumerate(decoders): + #print('209', transition_preds[b][0]) + actions_list.append(transition_preds[b][0]) + if not compute[b]: + stop, bfront = stack_top[b][-1].item(), buffer_front[b][-1].item() + + else: + result = decoder.apply_transition(transition_preds[b], trel_preds[b].item()) + if result is None: + stop, bfront = stack_top[b][-1].item(), buffer_front[b][-1].item() + #print("2019 stop, bfront", stop, bfront) + + + compute[b] = False + else: + #print() + stop, bfront = result[0].ID, result[1].ID + stack_list.append(stop) + buffer_list.append(bfront) + #print("225 stop, bfront", stop, bfront) + + + stack_top[b] = torch.concat([stack_top[b], torch.tensor([stop])]) + buffer_front[b] = torch.concat([buffer_front[b], torch.tensor([bfront])]) + counter += 1 + + head_preds = [[node.HEAD for node in decoder.decoded_nodes] for decoder in decoders] + deprel_preds = [[node.DEPREL for node in decoder.decoded_nodes] for decoder in decoders] + deprel_preds_decoded = [[self.TREL.vocab[i] for i in dep_pre] for dep_pre in deprel_preds] + pos_preds = [[i.item() for i in trans_stag] for trans_stag in transition_stags] + pos_preds_decoded = [[self.TAG.vocab[i.item()] for i in trans_stag] for trans_stag in transition_stags] + form_preds = [[node.FORM for node in decoder.decoded_nodes] for decoder in decoders] + #print(len(stack_list), stack_list) + #print(len(buffer_list), buffer_list) + #print(len(actions_list), actions_list) + #assert len(stack_list) == len(buffer_list) == len(actions_list) + + """ + print() + + print('stack list: ', stack_list) + print('buffer list: ', buffer_list) + print('actions list: ', actions_list) + print() + print('head_preds: ', head_preds) + print('deprel_preds: ', deprel_preds, deprel_preds_decoded) + print('pos_preds: ', pos_preds, pos_preds_decoded) + print() + """ + + act_dict = {} + for i, j in zip(buffer_list, actions_list): + act_dict.setdefault(i, []).append(j) + + act_dict_list = [] + for _,vel in act_dict.items(): + new_list = consecutive_duplicate_spans(vel) + act_dict_list.append(" ".join(new_list)) + + #print('macro actions: ', act_dict_list) + + #assert len(act_dict_list) == len(head_preds[0]) + + + return head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, act_dict + + @classmethod + def build(cls, path, min_freq=1, fix_len=20, **kwargs): + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + + # ------------------------------- source fields ------------------------------- + WORD, TAG, CHAR = None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, eos=EOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, eos=EOS, fix_len=args.fix_len, delay=args.delay) + if 'tag' in args.feat: + TAG = Field('tags') + TAG = Field('tags') + TEXT = RawField('texts') + STACK_TOP = RawField('stack_top', fn=lambda x: torch.tensor(x)) + BUFFER_FRONT = RawField('buffer_front', fn=lambda x: torch.tensor(x)) + + # ------------------------------- target fields ------------------------------- + TRANSITION = Field('transition') + TREL = Field('trels') + HEAD = Field('heads', use_vocab=False) + DEPREL = RawField('rels') + + transform = ArcEagerTransform( + FORM=(WORD, TEXT, CHAR), UPOS=TAG, HEAD=HEAD, DEPREL=DEPREL, + STACK_TOP=STACK_TOP, BUFFER_FRONT=BUFFER_FRONT, TRANSITION=TRANSITION, TREL=TREL, + ) + train = Dataset(transform, args.train, **args) + + if args.encoder != 'bert': + print('HOLY SH dep eager parser') + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), + lambda x: x / torch.std(x)) + if TAG: + TAG.build(train) + if CHAR: + + CHAR.build(train) + TAG.build(train) + TREL.build(train) + TRANSITION.build(train) + + print('TAG VOCAB', TAG.vocab.items()) + #print('CHAR VOCAB', CHAR.vocab.items()) + + + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_transitions': len(TRANSITION.vocab), + 'n_trels': len(TREL.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index + }) + + + logger.info(f"{transform}") + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/dep/eager/transform.py b/tania_scripts/supar/models/dep/eager/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..be5c883b8abba8f07d14ca75ddbfc8adaf9fa2f8 --- /dev/null +++ b/tania_scripts/supar/models/dep/eager/transform.py @@ -0,0 +1,176 @@ +from supar.utils.transform import Transform, Sentence +from supar.utils.field import Field +from typing import Iterable, Union, Optional, List, Tuple +from supar.models.dep.eager.oracle.arceager import ArcEagerEncoder +from supar.models.dep.eager.oracle.node import Node +import os +from io import StringIO + + +class ArcEagerTransform(Transform): + + fields = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC', 'STACK_TOP', 'BUFFER_FRONT', 'TRANSITION', 'TREL'] + + def __init__( + self, + ID: Optional[Union[Field, Iterable[Field]]] = None, + FORM: Optional[Union[Field, Iterable[Field]]] = None, + LEMMA: Optional[Union[Field, Iterable[Field]]] = None, + UPOS: Optional[Union[Field, Iterable[Field]]] = None, + XPOS: Optional[Union[Field, Iterable[Field]]] = None, + FEATS: Optional[Union[Field, Iterable[Field]]] = None, + HEAD: Optional[Union[Field, Iterable[Field]]] = None, + DEPREL: Optional[Union[Field, Iterable[Field]]] = None, + DEPS: Optional[Union[Field, Iterable[Field]]] = None, + MISC: Optional[Union[Field, Iterable[Field]]] = None, + STACK_TOP: Optional[Union[Field, Iterable[Field]]] = None, + BUFFER_FRONT: Optional[Union[Field, Iterable[Field]]] = None, + TRANSITION: Optional[Union[Field, Iterable[Field]]] = None, + TREL: Optional[Union[Field, Iterable[Field]]] = None + ): + super().__init__() + + self.ID = ID + self.FORM = FORM + self.LEMMA = LEMMA + self.UPOS = UPOS + self.XPOS = XPOS + self.FEATS = FEATS + self.HEAD = HEAD + self.DEPREL = DEPREL + self.DEPS = DEPS + self.MISC = MISC + self.STACK_TOP = STACK_TOP + self.BUFFER_FRONT = BUFFER_FRONT + self.TRANSITION = TRANSITION + self.TREL = TREL + + @property + def src(self): + return self.FORM, self.LEMMA, self.UPOS, self.XPOS, self.FEATS + + @classmethod + def toconll(cls, tokens: List[Union[str, Tuple]]) -> str: + r""" + Converts a list of tokens to a string in CoNLL-X format with missing fields filled with underscores. + + Args: + tokens (List[Union[str, Tuple]]): + This can be either a list of words, word/pos pairs or word/lemma/pos triples. + + Returns: + A string in CoNLL-X format. + + Examples: + >>> print(CoNLL.toconll(['She', 'enjoys', 'playing', 'tennis', '.'])) + 1 She _ _ _ _ _ _ _ _ + 2 enjoys _ _ _ _ _ _ _ _ + 3 playing _ _ _ _ _ _ _ _ + 4 tennis _ _ _ _ _ _ _ _ + 5 . _ _ _ _ _ _ _ _ + + >>> print(CoNLL.toconll([('She', 'she', 'PRP'), + ('enjoys', 'enjoy', 'VBZ'), + ('playing', 'play', 'VBG'), + ('tennis', 'tennis', 'NN'), + ('.', '_', '.')])) + 1 She she PRP _ _ _ _ _ _ + 2 enjoys enjoy VBZ _ _ _ _ _ _ + 3 playing play VBG _ _ _ _ _ _ + 4 tennis tennis NN _ _ _ _ _ _ + 5 . _ . _ _ _ _ _ _ + + """ + if isinstance(tokens[0], str): + s = '\n'.join([f"{i}\t{word}\t" + '\t'.join(['_'] * 8) + for i, word in enumerate(tokens, 1)]) + elif len(tokens[0]) == 2: + s = '\n'.join([f"{i}\t{word}\t_\t{tag}\t" + '\t'.join(['_'] * 6) + for i, (word, tag) in enumerate(tokens, 1)]) + elif len(tokens[0]) == 3: + s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t" + '\t'.join(['_'] * 6) + for i, (word, lemma, tag) in enumerate(tokens, 1)]) + elif len('85!!', tokens[0]) == 10: + s = '\n'.join([f"{i}\t{word}\t{lemma}\t{tag}\t{xpos}\t{upos}\t{head}\t{rel}\t{comm}\t{morph}" + for (i, word, lemma, tag, xpos, upos, head, rel, comm, morph) in tokens]) + else: + raise RuntimeError(f"Invalid sequence {tokens}. Only list of str or list of word/pos/lemma tuples are support.") + return s + '\n' + + @property + def tgt(self): + return self.HEAD, self.DEPREL, self.DEPS, self.MISC, self.STACK_TOP, self.BUFFER_FRONT, self.TRANSITION, self.TREL + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ): + if isinstance(data, str) and os.path.exists(data): + if os.path.isfile(data): + lines = open(data) + if os.path.isdir(data): + lines = [] + for filepath in data: + if filepath.endswith('.conllu'): + l = open(filepath) + lines.append(l) + else: + if lang is not None: + #data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)] + data = [data.split() if isinstance(data, str) else data] + else: + data = [data] if isinstance(data[0], str) else data + lines = (i for s in data for i in StringIO(self.toconll(s) + '\n')) + + index, sentence = 0, [] + for line in lines: + line = line.strip() + if len(line) == 0: + sentence = ArcEagerSentence(self, sentence, index) + yield sentence + index += 1 + sentence = [] + else: + sentence.append(line) + + + +class ArcEagerSentence(Sentence): + def __init__(self, transform: ArcEagerEncoder, lines: List[str], index: Optional[int] = None): + super().__init__(transform, index) + self.values = list() + self.annotations = dict() + + for i, line in enumerate(lines): + value = line.split('\t') + if value[0].startswith('#') or not value[0].isdigit(): + self.annotations[-i-1] = line + else: + self.annotations[len(self.values)] = line + self.values.append(value) + + nodes = [Node.from_conllu('\t'.join(value)) for value in self.values] + + + algorithm = ArcEagerEncoder(bos=transform.FORM[0].bos, eos=transform.FORM[0].eos) + transitions = algorithm.encode(nodes.copy()) + stack_top, buffer_front, transition, trel = zip( + *[(transition.stack_top.ID, transition.buffer_front.ID, transition.type, transition.deprel) + for transition in transitions]) + + self.values = list(zip(*self.values)) + if self.values[6][0] != '_': + self.values[6] = tuple(map(int, self.values[6])) + self.values += [stack_top, buffer_front, transition, trel] + + + def __repr__(self): + # cover the raw lines + #print('µµµµµµµµ', self.values[:-4]) + merged = {**self.annotations, + **{i: '\t'.join(map(str, line)) + for i, line in enumerate(zip(*self.values[:-4]))}} + + return '\n'.join(merged.values()) + '\n' \ No newline at end of file diff --git a/tania_scripts/supar/models/dep/sl/__init__.py b/tania_scripts/supar/models/dep/sl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e75ee7b541dea11bd85069b1c4cb47e8dce78c --- /dev/null +++ b/tania_scripts/supar/models/dep/sl/__init__.py @@ -0,0 +1,2 @@ +from .model import SLDependencyModel +from .parser import SLDependencyParser \ No newline at end of file diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e84041e18d76d00c1a56d47aa5947ed6eabe7f48 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da5b0abda7bcaada2f81498ced6765d9b1bb222 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eca97c4fc223779fc95ca035a2bb17aa33cb321 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f0f4cad3dcc22e7872d1c2ee274f49f8582da16 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95e1cddfd1b7d2af2f9bbf1573def416297a278b Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b16ba48675cbab2d7e21471d46322179b0928fb Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df4678723720a507453f379f951e9dbd859a3a2 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..027f048a42a233403d83761297506d43e3d196b9 Binary files /dev/null and b/tania_scripts/supar/models/dep/sl/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/sl/model.py b/tania_scripts/supar/models/dep/sl/model.py new file mode 100644 index 0000000000000000000000000000000000000000..79e5ec7e5a5ac1c99fbe2218d3f4a46aee6069ac --- /dev/null +++ b/tania_scripts/supar/models/dep/sl/model.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, DecoderLSTM +from supar.utils import Config +from typing import Tuple, List, Union + +class SLDependencyModel(Model): + def __init__(self, + n_words: int, + n_labels: Union[Tuple[int], int], + n_rels: int, + n_tags: int = None, + n_chars: int = None, + encoder: str ='lstm', + feat: List[str] = [], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + # create decoder + self.label_decoder, self.rel_decoder = None, None + if self.args.decoder == 'lstm': + decoder = lambda out_dim: DecoderLSTM(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + input_size=self.args.n_encoder_hidden, hidden_size=self.args.n_encoder_hidden, + num_layers=self.args.n_decoder_layers, dropout=mlp_dropout, + output_size=out_dim) + else: + decoder = lambda out_dim: MLP( + n_in=self.args.n_encoder_hidden, n_out=out_dim, dropout=mlp_dropout + ) + + if self.args.codes == '2p': + self.label_decoder1 = decoder(self.args.n_labels[0]) + self.label_decoder2 = decoder(self.args.n_labels[1]) + self.label_decoder = lambda x: (self.label_decoder1(x), self.label_decoder2(x)) + else: + self.label_decoder = decoder(self.args.n_labels) + + self.rel_decoder = decoder(self.args.n_rels) + + # create delay projection + if self.args.delay != 0: + self.delay_proj = MLP(n_in=self.args.n_encoder_hidden * (self.args.delay + 1), + n_out=self.args.n_encoder_hidden, dropout=mlp_dropout) + + # create PoS tagger + if self.args.encoder == 'lstm': + self.pos_tagger = DecoderLSTM( + input_size=self.args.n_encoder_hidden, hidden_size=self.args.n_encoder_hidden, + output_size=self.args.n_tags, num_layers=1, dropout=mlp_dropout, device=self.device) + else: + self.pos_tagger = nn.Identity() + + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words: torch.Tensor, feats: List[torch.Tensor] = None) -> Tuple[torch.Tensor]: + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + s_label (~Union[torch.Tensor, Tuple[torch.Tensor]]): ``[batch_size, seq_len, n_labels]`` + Tensor or 2-dimensional tensor tuple (if 2-planar bracketing coding is being used) which holds the + scores of all possible labels. + s_rel (~torch.Tensor): ``[batch_size, seq_len, n_rels]`` + Holds scores of all possible dependency relations. + s_tag (~torch.Tensor): ` [batch_size, seq_len, n_tags]`` + Holds scores of all possible tags for each word. + qloss (~torch.Tensor): + Vector quantization loss. + """ + + # x ~ [batch_size, bos + pad(seq_len) + delay, n_encoder_hidden] + x = self.encode(words, feats) + x = x[:, 1:, :] # remove BoS token + + # s_tag ~ [batch_size, pad(seq_len), n_tags] + s_tag = self.pos_tagger(x if self.args.delay == 0 else x[:, :-self.args.delay, :]) + + # map or concatenate delayed representations + if self.args.delay != 0: + x = torch.cat([x[:, i:(x.shape[1] - self.args.delay + i), :] for i in range(self.args.delay+1)], dim=2) + x = self.delay_proj(x) + + # x ~ [batch_size, pad(seq_len), n_encoder_hidden] + batch_size, pad_seq_len, _ = x.shape + + # pass through vector quantization module + x, qloss = self.vq_forward(x) + + # make predictions of labels/relations + s_label = self.label_decoder(x) + s_rel = self.rel_decoder(x) + + return s_label, s_rel, s_tag, qloss + + def loss( + self, + s_label: Union[Tuple[torch.Tensor], torch.Tensor], + s_rel: torch.Tensor, + s_tag: torch.Tensor, + labels: Union[Tuple[torch.Tensor], torch.Tensor], + rels: torch.Tensor, + tags: torch.Tensor, + mask: torch.Tensor + ) -> torch.Tensor: + + loss = self.criterion(s_label[mask], labels[mask]) if self.args.codes != '2p' else sum(self.criterion(scores[mask], golds[mask]) for scores, golds in zip(s_label, labels)) + + loss += self.criterion(s_rel[mask], rels[mask]) + + if self.args.encoder == 'lstm': + loss += self.criterion(s_tag[mask], tags[mask]) + return loss + + def decode(self, s_label: Union[Tuple[torch.Tensor], torch.Tensor], s_rel: torch.Tensor, s_tag: torch.Tensor, + mask: torch.Tensor): + label_preds = s_label.argmax(-1) if self.args.codes != '2p' else tuple(map(lambda x: x.argmax(-1), s_label)) + rel_preds = s_rel.argmax(-1) + tag_preds = s_tag.argmax(-1) if self.args.encoder == 'lstm' else None + return label_preds, rel_preds, tag_preds diff --git a/tania_scripts/supar/models/dep/sl/parser.py b/tania_scripts/supar/models/dep/sl/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..af95475824a3b5719b4ff13ab955cdd51b5493bf --- /dev/null +++ b/tania_scripts/supar/models/dep/sl/parser.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- + +import os, re +from typing import Iterable, Union, List + +import torch +from supar.models.dep.sl.model import SLDependencyModel +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, PAD, UNK +from supar.utils.field import Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch +from supar.models.dep.sl.transform import SLDependency +from supar.codelin import get_dep_encoder, LinearizedTree, D_Label +from supar.codelin.utils.constants import D_ROOT_HEAD +from supar.codelin import LABEL_SEPARATOR +logger = get_logger(__name__) +from typing import Tuple, List, Union + + +NONE = '<none>' +OPTIONS = [r'>\*', r'<\*', r'/\*', r'\\\*'] + + +def split_planes(labels: Tuple[str], plane: int) -> Tuple[str]: + + split = lambda label: min(match.span()[0] if match is not None else len(label) for match in map(lambda x: re.search(x, label), OPTIONS)) + splits = list(map(split, labels)) + if plane == 0: + return tuple(label[:i] for label, i in zip(labels, splits)) + else: + return tuple(label[i:] if i < len(label) else NONE for label, i in zip(labels, splits)) + +class SLDependencyParser(Parser): + + NAME = 'SLDependencyParser' + MODEL = SLDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.TAG = self.transform.UPOS + self.LABEL, self.DEPREL = self.transform.LABEL, self.transform.DEPREL + self.encoder = get_dep_encoder(self.args.codes, LABEL_SEPARATOR) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 1000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = False, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def few_shot( + self, + train: str, + dev: Union[str, Iterable], + test: Union[str, Iterable], + n_samples: int, + epochs: int = 2, + batch_size: int = 50 + ) -> None: + return super().few_shot(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + tree: bool = True, + proj: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 500, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = True, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + if self.args.encoder == 'lstm': + words, texts, chars, tags, _, rels, *labels = batch + feats = [chars] + else: + words, texts, tags, _, rels, *labels = batch + feats = [] + labels = labels[0] if len(labels) == 1 else labels + mask = batch.mask[:, (1+self.args.delay):] + + # forward pass + s_label, s_rel, s_tag, qloss = self.model(words, feats) + + # compute loss + loss = self.model.loss(s_label, s_rel, s_tag, labels, rels, tags, mask) + qloss + + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + if self.args.encoder == 'lstm': + words, texts, chars, tags, heads, rels, *labels = batch + feats = [chars] + else: + words, texts, tags, heads, rels, *labels = batch + feats = [] + labels = labels[0] if len(labels) == 1 else labels + mask = batch.mask[:, (1+self.args.delay):] + lens = (batch.lens - 1 - self.args.delay).tolist() + + # forward pass + s_label, s_rel, s_tag, qloss = self.model(words, feats) + + # compute loss and decode + loss = self.model.loss(s_label, s_rel, s_tag, labels, rels, tags, mask) + qloss + label_preds, rel_preds, tag_preds = self.model.decode(s_label, s_rel, s_tag, mask) + + # obtain original label and deprel strings to compute decoding + if self.args.codes == '2p': + label_preds = [ + (self.LABEL[0].vocab[l1.tolist()], self.LABEL[1].vocab[l2.tolist()]) + for l1, l2 in zip(*map(lambda x: x[mask].split(lens), label_preds)) + ] + label_preds = [ + list(map(lambda x: x[0] + (x[1] if x[1] != NONE else ''), zip(*label_pred))) + for label_pred in label_preds + ] + else: + label_preds = [self.LABEL.vocab[i.tolist()] for i in label_preds[mask].split(lens)] + deprel_preds = [self.DEPREL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + + if self.args.encoder == 'lstm': + tag_preds = [self.TAG.vocab[i.tolist()] for i in tag_preds[mask].split(lens)] + else: + tag_preds = [self.TAG.vocab[i.tolist()] for i in tags[mask].split(lens)] + + # decode + head_preds = list() + for label_pred, deprel_pred, tag_pred, forms in zip(label_preds, deprel_preds, tag_preds, texts): + labels = [D_Label(label, deprel, self.encoder.separator) for label, deprel in zip(label_pred, deprel_pred)] + linearized_tree = LinearizedTree(list(forms), tag_pred, ['_']*len(forms), labels, 0) + + decoded_tree = self.encoder.decode(linearized_tree) + decoded_tree.postprocess_tree(D_ROOT_HEAD) + head_preds.append(torch.tensor([int(node.head) for node in decoded_tree.nodes])) + + + # resize head predictions (add padding) + resize = lambda list_of_tensors: \ + torch.stack([ + torch.concat([x, torch.zeros(mask.shape[1] - len(x))]) + for x in list_of_tensors]) + + head_preds = resize(head_preds).to(torch.int32).to(self.model.device) + + return AttachmentMetric(loss, (head_preds, rel_preds), (heads, rels), mask) + + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + if self.args.encoder == 'lstm': + words, texts, *feats, tags = batch + else: + words, texts, *feats, tags = batch + mask = batch.mask[:, (1 + self.args.delay):] + lens = (batch.lens - 1 - self.args.delay).tolist() + + # forward pass + s_label, s_rel, s_tag, qloss = self.model(words, feats) + + # compute loss and decode + label_preds, rel_preds, tag_preds = self.model.decode(s_label, s_rel, s_tag, mask) + + # obtain original label and deprel strings to compute decoding + if self.args.codes == '2p': + label_preds = [ + (self.LABEL[0].vocab[l1.tolist()], self.LABEL[1].vocab[l2.tolist()]) + for l1, l2 in zip(*map(lambda x: x[mask].split(lens), label_preds)) + ] + label_preds = [ + list(map(lambda x: x[0] + (x[1] if x[1] != NONE else ''), zip(*label_pred))) + for label_pred in label_preds + ] + else: + label_preds = [self.LABEL.vocab[i.tolist()] for i in label_preds[mask].split(lens)] + + deprel_preds = [self.DEPREL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + + if self.args.encoder == 'lstm': + tag_preds = [self.TAG.vocab[i.tolist()] for i in tag_preds[mask].split(lens)] + else: + tag_preds = [self.TAG.vocab[i.tolist()] for i in tags[mask].split(lens)] + + # decode + head_preds = list() + for label_pred, deprel_pred, tag_pred, forms in zip(label_preds, deprel_preds, tag_preds, texts): + labels = [D_Label(label, deprel, self.encoder.separator) for label, deprel in zip(label_pred, deprel_pred)] + linearized_tree = LinearizedTree(forms, tag_pred, ['_'] * len(forms), labels, 0) + + decoded_tree = self.encoder.decode(linearized_tree) + decoded_tree.postprocess_tree(D_ROOT_HEAD) + + head_preds.append([int(node.head) for node in decoded_tree.nodes]) + + batch.heads = head_preds + batch.rels = deprel_preds + + return batch + + @classmethod + def build(cls, path, min_freq=2, fix_len=20, **kwargs): + args = Config(**locals()) + + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + CHAR = None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + pad_token = t.pad if t.pad else PAD + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t, delay=args.delay) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True, delay=args.delay) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len, delay=args.delay) + TEXT = RawField('texts') + TAG = Field('tags') + + + if args.codes == '2p': + LABEL1 = Field('labels1', fn=lambda seq: split_planes(seq, 0)) + LABEL2 = Field('labels2', fn=lambda seq: split_planes(seq, 1)) + LABEL = (LABEL1, LABEL2) + else: + LABEL = Field('labels') + + DEPREL = Field('rels') + HEAD = Field('heads', use_vocab=False) + + transform = SLDependency( + encoder=get_dep_encoder(args.codes, LABEL_SEPARATOR), + FORM=(WORD, TEXT, CHAR), UPOS=TAG, HEAD=HEAD, DEPREL=DEPREL, LABEL=LABEL) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if CHAR: + CHAR.build(train) + TAG.build(train) + DEPREL.build(train) + + if args.codes == '2p': + LABEL[0].build(train) + LABEL[1].build(train) + else: + LABEL.build(train) + + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(LABEL.vocab) if args.codes != '2p' else (len(LABEL[0].vocab), len(LABEL[1].vocab)), + 'n_rels': len(DEPREL.vocab), + 'n_tags': len(TAG.vocab), + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index, + 'delay': args.delay + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser + + + diff --git a/tania_scripts/supar/models/dep/sl/transform.py b/tania_scripts/supar/models/dep/sl/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..8603cb86793fc9f49863542634aaa74ffe7e9fe9 --- /dev/null +++ b/tania_scripts/supar/models/dep/sl/transform.py @@ -0,0 +1,102 @@ +from supar.utils.transform import Transform, Sentence +from supar.utils.field import Field +from typing import Iterable, List, Optional, Union +from supar.codelin import D_Tree + +class SLDependency(Transform): + + fields = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC', 'LABEL'] + + def __init__( + self, + encoder, + ID: Optional[Union[Field, Iterable[Field]]] = None, + FORM: Optional[Union[Field, Iterable[Field]]] = None, + LEMMA: Optional[Union[Field, Iterable[Field]]] = None, + UPOS: Optional[Union[Field, Iterable[Field]]] = None, + XPOS: Optional[Union[Field, Iterable[Field]]] = None, + FEATS: Optional[Union[Field, Iterable[Field]]] = None, + HEAD: Optional[Union[Field, Iterable[Field]]] = None, + DEPREL: Optional[Union[Field, Iterable[Field]]] = None, + DEPS: Optional[Union[Field, Iterable[Field]]] = None, + MISC: Optional[Union[Field, Iterable[Field]]] = None, + LABEL: Optional[Union[Field, Iterable[Field]]] = None + ): + super().__init__() + + self.encoder = encoder + self.ID = ID + self.FORM = FORM + self.LEMMA = LEMMA + self.UPOS = UPOS + self.XPOS = XPOS + self.FEATS = FEATS + self.HEAD = HEAD + self.DEPREL = DEPREL + self.DEPS = DEPS + self.MISC = MISC + self.LABEL = LABEL + + @property + def src(self): + return self.FORM, self.LEMMA, self.UPOS, self.XPOS, self.FEATS + + @property + def tgt(self): + return self.HEAD, self.DEPREL, self.DEPS, self.MISC, self.LABEL + + def load( + self, + data: str, + **kwargs + ): + lines = open(data) + index, sentence = 0, [] + for line in lines: + line = line.strip() + if len(line) == 0: + sentence = SLDependencySentence(self, sentence, self.encoder, index) + if sentence.values: + yield sentence + index += 1 + sentence = [] + else: + sentence.append(line) + + + +class SLDependencySentence(Sentence): + def __init__(self, transform: SLDependency, lines: List[str], encoder, index: Optional[int] = None): + super().__init__(transform, index) + self.annotations = dict() + self.values = list() + + for i, line in enumerate(lines): + value = line.split('\t') + if value[0].startswith('#') or not value[0].isdigit(): + self.annotations[-i-1] = line + else: + self.annotations[len(self.values)] = line + self.values.append(value) + + # convert values into nodes + tree = D_Tree.from_string('\n'.join(['\t'.join(value) for value in self.values])) + + # linearize tree + linearized_tree = encoder.encode(tree) + + self.values = list(zip(*self.values)) + self.values[6] = tuple(map(int, self.values[6])) + + # add labels + labels = tuple(str(label.xi) for label in linearized_tree.labels) + self.values.append(labels) + + + + def __repr__(self): + # cover the raw lines + merged = {**self.annotations, + **{i: '\t'.join(map(str, line)) + for i, line in enumerate(zip(*self.values[:-1]))}} + return '\n'.join(merged.values()) + '\n' diff --git a/tania_scripts/supar/models/dep/vi/__init__.py b/tania_scripts/supar/models/dep/vi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18dc35558d4f383929d588d3f70d319684afb7ec --- /dev/null +++ b/tania_scripts/supar/models/dep/vi/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import VIDependencyModel +from .parser import VIDependencyParser + +__all__ = ['VIDependencyModel', 'VIDependencyParser'] diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78075f4745b4926add51043c086c330542182c05 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b920378d10fc9cdbec8a67d9522ea76e4b0702 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a980f9eebf2cc021c1f6a146c773ebb86340a75 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d903b7c1fad5509a97624f635495068bf48f4a90 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daf22311cfc35994234b5e1731ac3cb9e418b269 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94aaa00287400fd00c2dabd76be8508ea9161706 Binary files /dev/null and b/tania_scripts/supar/models/dep/vi/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/dep/vi/model.py b/tania_scripts/supar/models/dep/vi/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad8c3a123f08919901ec3c06ed404a6d28385cb --- /dev/null +++ b/tania_scripts/supar/models/dep/vi/model.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from supar.models.dep.biaffine.model import BiaffineDependencyModel +from supar.models.dep.biaffine.transform import CoNLL +from supar.modules import MLP, Biaffine, Triaffine +from supar.structs import (DependencyCRF, DependencyLBP, DependencyMFVI, + MatrixTree) +from supar.utils import Config +from supar.utils.common import MIN + + +class VIDependencyModel(BiaffineDependencyModel): + r""" + The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`. + + Args: + n_words (int): + The size of the word vocabulary. + n_rels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [``'char'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 100. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .33. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 800. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_arc_mlp (int): + Arc MLP size. Default: 500. + n_sib_mlp (int): + Binary factor MLP size. Default: 100. + n_rel_mlp (int): + Label MLP size. Default: 100. + mlp_dropout (float): + The dropout ratio of MLP layers. Default: .33. + scale (float): + Scaling factor for affine scores. Default: 0. + inference (str): + Approximate inference methods. Default: ``mfvi``. + max_iter (int): + Max iteration times for inference. Default: 3. + interpolation (int): + Constant to even out the label/edge loss. Default: .1. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_rels, + n_tags=None, + n_chars=None, + encoder='lstm', + feat=['char'], + n_embed=100, + n_pretrained=100, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.33, + n_encoder_hidden=800, + n_encoder_layers=3, + encoder_dropout=.33, + n_arc_mlp=500, + n_sib_mlp=100, + n_rel_mlp=100, + mlp_dropout=.33, + scale=0, + inference='mfvi', + max_iter=3, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.arc_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.arc_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_arc_mlp, dropout=mlp_dropout) + self.sib_mlp_s = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.sib_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.sib_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_sib_mlp, dropout=mlp_dropout) + self.rel_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + self.rel_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_rel_mlp, dropout=mlp_dropout) + + self.arc_attn = Biaffine(n_in=n_arc_mlp, scale=scale, bias_x=True, bias_y=False) + self.sib_attn = Triaffine(n_in=n_sib_mlp, scale=scale, bias_x=True, bias_y=True) + self.rel_attn = Biaffine(n_in=n_rel_mlp, n_out=n_rels, bias_x=True, bias_y=True) + self.inference = (DependencyMFVI if inference == 'mfvi' else DependencyLBP)(max_iter) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor, ~torch.Tensor: + Scores of all possible arcs (``[batch_size, seq_len, seq_len]``), + dependent-head-sibling triples (``[batch_size, seq_len, seq_len, seq_len]``) and + all possible labels on each arc (``[batch_size, seq_len, seq_len, n_labels]``). + """ + + x = self.encode(words, feats) + mask = words.ne(self.args.pad_index) if len(words.shape) < 3 else words.ne(self.args.pad_index).any(-1) + + arc_d = self.arc_mlp_d(x) + arc_h = self.arc_mlp_h(x) + sib_s = self.sib_mlp_s(x) + sib_d = self.sib_mlp_d(x) + sib_h = self.sib_mlp_h(x) + rel_d = self.rel_mlp_d(x) + rel_h = self.rel_mlp_h(x) + + # [batch_size, seq_len, seq_len] + s_arc = self.arc_attn(arc_d, arc_h).masked_fill_(~mask.unsqueeze(1), MIN) + # [batch_size, seq_len, seq_len, seq_len] + s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2) + # [batch_size, seq_len, seq_len, n_rels] + s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1) + + return s_arc, s_sib, s_rel + + def loss(self, s_arc, s_sib, s_rel, arcs, rels, mask): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-sibling triples. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + arcs (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard arcs. + rels (~torch.LongTensor): ``[batch_size, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + + Returns: + ~torch.Tensor: + The training loss. + """ + + arc_loss, marginals = self.inference((s_arc, s_sib), mask, arcs) + s_rel, rels = s_rel[mask], rels[mask] + s_rel = s_rel[torch.arange(len(rels)), arcs[mask]] + rel_loss = self.criterion(s_rel, rels) + loss = arc_loss + rel_loss + return loss, marginals + + def decode(self, s_arc, s_rel, mask, tree=False, proj=False): + r""" + Args: + s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible arcs. + s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each arc. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + tree (bool): + If ``True``, ensures to output well-formed trees. Default: ``False``. + proj (bool): + If ``True``, ensures to output projective trees. Default: ``False``. + + Returns: + ~torch.LongTensor, ~torch.LongTensor: + Predicted arcs and labels of shape ``[batch_size, seq_len]``. + """ + + lens = mask.sum(1) + arc_preds = s_arc.argmax(-1) + bad = [not CoNLL.istree(seq[1:i+1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())] + if tree and any(bad): + arc_preds[bad] = (DependencyCRF if proj else MatrixTree)(s_arc[bad], mask[bad].sum(-1)).argmax + rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1) + + return arc_preds, rel_preds diff --git a/tania_scripts/supar/models/dep/vi/parser.py b/tania_scripts/supar/models/dep/vi/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..3808a3daa43247aa099f753e4a44d1cdfd480720 --- /dev/null +++ b/tania_scripts/supar/models/dep/vi/parser.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.parser import BiaffineDependencyParser +from supar.models.dep.vi.model import VIDependencyModel +from supar.utils import Config +from supar.utils.fn import ispunct +from supar.utils.logging import get_logger +from supar.utils.metric import AttachmentMetric +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class VIDependencyParser(BiaffineDependencyParser): + r""" + The implementation of Dependency Parser using Variational Inference :cite:`wang-tu-2020-second`. + """ + + NAME = 'vi-dependency' + MODEL = VIDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = False, + proj: bool = False, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + punct: bool = False, + tree: bool = True, + proj: bool = True, + partial: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + tree: bool = True, + proj: bool = True, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + loss, *_ = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> AttachmentMetric: + words, _, *feats, arcs, rels = batch + mask = batch.mask + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + if self.args.partial: + mask &= arcs.ge(0) + # ignore all punctuation if not specified + if not self.args.punct: + mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words])) + return AttachmentMetric(loss, (arc_preds, rel_preds), (arcs, rels), mask) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, _, *feats = batch + mask, lens = batch.mask, (batch.lens - 1).tolist() + # ignore the first token of each sentence + mask[:, 0] = 0 + s_arc, s_sib, s_rel = self.model(words, feats) + s_arc = self.model.inference((s_arc, s_sib), mask) + arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj) + batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)] + batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)] + if self.args.prob: + batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())] + return batch diff --git a/tania_scripts/supar/models/sdp/__init__.py b/tania_scripts/supar/models/sdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..633e238449fbf947b867006579f439d8b8e78156 --- /dev/null +++ b/tania_scripts/supar/models/sdp/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .biaffine import BiaffineSemanticDependencyModel, BiaffineSemanticDependencyParser +from .vi import VISemanticDependencyModel, VISemanticDependencyParser + +__all__ = ['BiaffineSemanticDependencyModel', 'BiaffineSemanticDependencyParser', + 'VISemanticDependencyModel', 'VISemanticDependencyParser'] diff --git a/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02080dfc284fa4d868f12d9e49fca7f0775a90f5 Binary files /dev/null and b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f809f66bbc799ef64855403f747872f7bcc4dc85 Binary files /dev/null and b/tania_scripts/supar/models/sdp/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__init__.py b/tania_scripts/supar/models/sdp/biaffine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2feeeb4e81b099086b73e78ae81d5fccbbe82c --- /dev/null +++ b/tania_scripts/supar/models/sdp/biaffine/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import BiaffineSemanticDependencyModel +from .parser import BiaffineSemanticDependencyParser + +__all__ = ['BiaffineSemanticDependencyModel', 'BiaffineSemanticDependencyParser'] diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4845ca571234eecab6e183073d4864e1ec667a85 Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9896718732a155aafb0ff99c6979f1448f90fcda Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3927ea658cbef06b9e352886b378c4743b3c16b Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c1c571e2bcd578989340f70f59cc638e9cb644e Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9143362fa65ddfd6a27203f97dd01c2a6a45f3e Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3e9e25a0966252386b63bb170c782847ca49afb Binary files /dev/null and b/tania_scripts/supar/models/sdp/biaffine/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/biaffine/model.py b/tania_scripts/supar/models/sdp/biaffine/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7afa4dde4f2d7dee3d764ab88d1e9c899a02d7 --- /dev/null +++ b/tania_scripts/supar/models/sdp/biaffine/model.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, Biaffine +from supar.utils import Config + + +class BiaffineSemanticDependencyModel(Model): + r""" + The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + n_lemmas (int): + The number of lemmas, required if lemma embeddings are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'lemma'``: Lemma embeddings. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [ ``'tag'``, ``'char'``, ``'lemma'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word representations. Default: 125. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .2. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 1200. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_edge_mlp (int): + Edge MLP size. Default: 600. + n_label_mlp (int): + Label MLP size. Default: 600. + edge_mlp_dropout (float): + The dropout ratio of edge MLP layers. Default: .25. + label_mlp_dropout (float): + The dropout ratio of label MLP layers. Default: .33. + interpolation (int): + Constant to even out the label/edge loss. Default: .1. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + n_lemmas=None, + encoder='lstm', + feat=['tag', 'char', 'lemma'], + n_embed=100, + n_pretrained=125, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=400, + char_pad_index=0, + char_dropout=0.33, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.2, + n_encoder_hidden=1200, + n_encoder_layers=3, + encoder_dropout=.33, + n_edge_mlp=600, + n_label_mlp=600, + edge_mlp_dropout=.25, + label_mlp_dropout=.33, + interpolation=0.1, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + + self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True) + self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True) + self.criterion = nn.CrossEntropyLoss() + + def load_pretrained(self, embed=None): + if embed is not None: + self.pretrained = nn.Embedding.from_pretrained(embed) + if embed.shape[1] != self.args.n_pretrained: + self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained) + return self + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges. + The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds + scores of all possible labels on each edge. + """ + + x = self.encode(words, feats) + + edge_d = self.edge_mlp_d(x) + edge_h = self.edge_mlp_h(x) + label_d = self.label_mlp_d(x) + label_h = self.label_mlp_h(x) + + # [batch_size, seq_len, seq_len, 2] + s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1) + # [batch_size, seq_len, seq_len, n_labels] + s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1) + + return s_edge, s_label + + def loss(self, s_edge, s_label, labels, mask): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``. + Scores of all possible edges. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + + Returns: + ~torch.Tensor: + The training loss. + """ + + edge_mask = labels.ge(0) & mask + edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long()) + label_loss = self.criterion(s_label[edge_mask], labels[edge_mask]) + return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss + + def decode(self, s_edge, s_label): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``. + Scores of all possible edges. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + + Returns: + ~torch.LongTensor: + Predicted labels of shape ``[batch_size, seq_len, seq_len]``. + """ + + return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1) diff --git a/tania_scripts/supar/models/sdp/biaffine/parser.py b/tania_scripts/supar/models/sdp/biaffine/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c28f6c229a5d33f433b2e19378b7966eb442fb74 --- /dev/null +++ b/tania_scripts/supar/models/sdp/biaffine/parser.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- + +import os +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.transform import CoNLL +from supar.models.sdp.biaffine import BiaffineSemanticDependencyModel +from supar.parser import Parser +from supar.utils import Config, Dataset, Embedding +from supar.utils.common import BOS, PAD, UNK +from supar.utils.field import ChartField, Field, RawField, SubwordField +from supar.utils.logging import get_logger +from supar.utils.metric import ChartMetric +from supar.utils.tokenizer import TransformerTokenizer +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class BiaffineSemanticDependencyParser(Parser): + r""" + The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`. + """ + + NAME = 'biaffine-semantic-dependency' + MODEL = BiaffineSemanticDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.LEMMA = self.transform.LEMMA + self.TAG = self.transform.POS + self.LABEL = self.transform.PHEAD + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, *feats, labels = batch + mask = batch.mask + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + s_edge, s_label = self.model(words, feats) + loss = self.model.loss(s_edge, s_label, labels, mask) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> ChartMetric: + words, *feats, labels = batch + mask = batch.mask + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + s_edge, s_label = self.model(words, feats) + loss = self.model.loss(s_edge, s_label, labels, mask) + label_preds = self.model.decode(s_edge, s_label) + return ChartMetric(loss, label_preds.masked_fill(~mask, -1), labels.masked_fill(~mask, -1)) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats = batch + mask, lens = batch.mask, (batch.lens - 1).tolist() + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + with torch.autocast(self.device, enabled=self.args.amp): + s_edge, s_label = self.model(words, feats) + label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1) + batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] + for row in chart[1:i, :i].tolist()]) + for i, chart in zip(lens, label_preds)] + if self.args.prob: + batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())] + return batch + + @classmethod + def build(cls, path, min_freq=7, fix_len=20, **kwargs): + r""" + Build a brand-new Parser, including initialization of all data fields and model parameters. + + Args: + path (str): + The path of the model to be saved. + min_freq (str): + The minimum frequency needed to include a token in the vocabulary. Default:7. + fix_len (int): + The max length of all subword pieces. The excess part of each piece will be truncated. + Required if using CharLSTM/BERT. + Default: 20. + kwargs (Dict): + A dict holding the unconsumed arguments. + """ + + args = Config(**locals()) + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + if os.path.exists(path) and not args.build: + parser = cls.load(**args) + parser.model = cls.MODEL(**parser.args) + parser.model.load_pretrained(parser.transform.FORM[0].embed).to(parser.device) + return parser + + logger.info("Building the fields") + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True) + TAG, CHAR, LEMMA, ELMO, BERT = None, None, None, None, None + if args.encoder == 'bert': + t = TransformerTokenizer(args.bert) + WORD = SubwordField('words', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + WORD.vocab = t.vocab + else: + WORD = Field('words', pad=PAD, unk=UNK, bos=BOS, lower=True) + if 'tag' in args.feat: + TAG = Field('tags', bos=BOS) + if 'char' in args.feat: + CHAR = SubwordField('chars', pad=PAD, unk=UNK, bos=BOS, fix_len=args.fix_len) + if 'lemma' in args.feat: + LEMMA = Field('lemmas', pad=PAD, unk=UNK, bos=BOS, lower=True) + if 'elmo' in args.feat: + from allennlp.modules.elmo import batch_to_ids + ELMO = RawField('elmo') + ELMO.compose = lambda x: batch_to_ids(x).to(WORD.device) + if 'bert' in args.feat: + t = TransformerTokenizer(args.bert) + BERT = SubwordField('bert', pad=t.pad, unk=t.unk, bos=t.bos, fix_len=args.fix_len, tokenize=t) + BERT.vocab = t.vocab + LABEL = ChartField('labels', fn=CoNLL.get_labels) + transform = CoNLL(FORM=(WORD, CHAR, ELMO, BERT), LEMMA=LEMMA, POS=TAG, PHEAD=LABEL) + + train = Dataset(transform, args.train, **args) + if args.encoder != 'bert': + WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x)) + if TAG is not None: + TAG.build(train) + if CHAR is not None: + CHAR.build(train) + if LEMMA is not None: + LEMMA.build(train) + LABEL.build(train) + args.update({ + 'n_words': len(WORD.vocab) if args.encoder == 'bert' else WORD.vocab.n_init, + 'n_labels': len(LABEL.vocab), + 'n_tags': len(TAG.vocab) if TAG is not None else None, + 'n_chars': len(CHAR.vocab) if CHAR is not None else None, + 'char_pad_index': CHAR.pad_index if CHAR is not None else None, + 'n_lemmas': len(LEMMA.vocab) if LEMMA is not None else None, + 'bert_pad_index': BERT.pad_index if BERT is not None else None, + 'pad_index': WORD.pad_index, + 'unk_index': WORD.unk_index, + 'bos_index': WORD.bos_index + }) + logger.info(f"{transform}") + + logger.info("Building the model") + model = cls.MODEL(**args).load_pretrained(WORD.embed if hasattr(WORD, 'embed') else None) + logger.info(f"{model}\n") + + parser = cls(args, model, transform) + parser.model.to(parser.device) + return parser diff --git a/tania_scripts/supar/models/sdp/vi/__init__.py b/tania_scripts/supar/models/sdp/vi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2aae65de0fcf695c9c98e76c3f6575e9b69c5577 --- /dev/null +++ b/tania_scripts/supar/models/sdp/vi/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .model import VISemanticDependencyModel +from .parser import VISemanticDependencyParser + +__all__ = ['VISemanticDependencyModel', 'VISemanticDependencyParser'] diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b7e00e74f70f0d965c505e8994b7e119a0dec34 Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e17f748d964d58b705b5a14d52e4f3758f8318e7 Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cfb7bd073d93cc874641db0353d3cc92dcbf402 Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac266759295a48ebd875a15b1b28e225c633c0b6 Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/model.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb658d3c6b5f0ceae68bdf73fc482f9763c6d5f Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-310.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b28e31b2c498c355f12078f27de07f420d6c8a0 Binary files /dev/null and b/tania_scripts/supar/models/sdp/vi/__pycache__/parser.cpython-311.pyc differ diff --git a/tania_scripts/supar/models/sdp/vi/model.py b/tania_scripts/supar/models/sdp/vi/model.py new file mode 100644 index 0000000000000000000000000000000000000000..12c20e1af605ee1c918ddcafe3bb64a9baa52146 --- /dev/null +++ b/tania_scripts/supar/models/sdp/vi/model.py @@ -0,0 +1,471 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from supar.model import Model +from supar.modules import MLP, Biaffine, Triaffine +from supar.structs import SemanticDependencyLBP, SemanticDependencyMFVI +from supar.utils import Config + + +class BiaffineSemanticDependencyModel(Model): + r""" + The implementation of Biaffine Semantic Dependency Parser :cite:`dozat-manning-2018-simpler`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + n_lemmas (int): + The number of lemmas, required if lemma embeddings are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'lemma'``: Lemma embeddings. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [ ``'tag'``, ``'char'``, ``'lemma'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word representations. Default: 125. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .2. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 1200. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_edge_mlp (int): + Edge MLP size. Default: 600. + n_label_mlp (int): + Label MLP size. Default: 600. + edge_mlp_dropout (float): + The dropout ratio of edge MLP layers. Default: .25. + label_mlp_dropout (float): + The dropout ratio of label MLP layers. Default: .33. + interpolation (int): + Constant to even out the label/edge loss. Default: .1. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + n_lemmas=None, + encoder='lstm', + feat=['tag', 'char', 'lemma'], + n_embed=100, + n_pretrained=125, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=400, + char_pad_index=0, + char_dropout=0.33, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.2, + n_encoder_hidden=1200, + n_encoder_layers=3, + encoder_dropout=.33, + n_edge_mlp=600, + n_label_mlp=600, + edge_mlp_dropout=.25, + label_mlp_dropout=.33, + interpolation=0.1, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + + self.edge_attn = Biaffine(n_in=n_edge_mlp, n_out=2, bias_x=True, bias_y=True) + self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True) + self.criterion = nn.CrossEntropyLoss() + + def load_pretrained(self, embed=None): + if embed is not None: + self.pretrained = nn.Embedding.from_pretrained(embed) + if embed.shape[1] != self.args.n_pretrained: + self.embed_proj = nn.Linear(embed.shape[1], self.args.n_pretrained) + return self + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first tensor of shape ``[batch_size, seq_len, seq_len, 2]`` holds scores of all possible edges. + The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds + scores of all possible labels on each edge. + """ + + x = self.encode(words, feats) + + edge_d = self.edge_mlp_d(x) + edge_h = self.edge_mlp_h(x) + label_d = self.label_mlp_d(x) + label_h = self.label_mlp_h(x) + + # [batch_size, seq_len, seq_len, 2] + s_edge = self.edge_attn(edge_d, edge_h).permute(0, 2, 3, 1) + # [batch_size, seq_len, seq_len, n_labels] + s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1) + + return s_edge, s_label + + def loss(self, s_edge, s_label, labels, mask): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``. + Scores of all possible edges. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + + Returns: + ~torch.Tensor: + The training loss. + """ + + edge_mask = labels.ge(0) & mask + edge_loss = self.criterion(s_edge[mask], edge_mask[mask].long()) + label_loss = self.criterion(s_label[edge_mask], labels[edge_mask]) + return self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss + + def decode(self, s_edge, s_label): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len, 2]``. + Scores of all possible edges. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + + Returns: + ~torch.LongTensor: + Predicted labels of shape ``[batch_size, seq_len, seq_len]``. + """ + + return s_label.argmax(-1).masked_fill_(s_edge.argmax(-1).lt(1), -1) + + +class VISemanticDependencyModel(BiaffineSemanticDependencyModel): + r""" + The implementation of Semantic Dependency Parser using Variational Inference :cite:`wang-etal-2019-second`. + + Args: + n_words (int): + The size of the word vocabulary. + n_labels (int): + The number of labels in the treebank. + n_tags (int): + The number of POS tags, required if POS tag embeddings are used. Default: ``None``. + n_chars (int): + The number of characters, required if character-level representations are used. Default: ``None``. + n_lemmas (int): + The number of lemmas, required if lemma embeddings are used. Default: ``None``. + encoder (str): + Encoder to use. + ``'lstm'``: BiLSTM encoder. + ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. + Default: ``'lstm'``. + feat (List[str]): + Additional features to use, required if ``encoder='lstm'``. + ``'tag'``: POS tag embeddings. + ``'char'``: Character-level representations extracted by CharLSTM. + ``'lemma'``: Lemma embeddings. + ``'bert'``: BERT representations, other pretrained language models like RoBERTa are also feasible. + Default: [ ``'tag'``, ``'char'``, ``'lemma'``]. + n_embed (int): + The size of word embeddings. Default: 100. + n_pretrained (int): + The size of pretrained word embeddings. Default: 125. + n_feat_embed (int): + The size of feature representations. Default: 100. + n_char_embed (int): + The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. + n_char_hidden (int): + The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100. + char_pad_index (int): + The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. + elmo (str): + Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. + elmo_bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. + Default: ``(True, False)``. + bert (str): + Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. + This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. + Default: ``None``. + n_bert_layers (int): + Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. + The final outputs would be weighted sum of the hidden states of these layers. + Default: 4. + mix_dropout (float): + The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. + bert_pooling (str): + Pooling way to get token embeddings. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + bert_pad_index (int): + The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. + Default: 0. + finetune (bool): + If ``False``, freezes all parameters, required if using pretrained layers. Default: ``False``. + n_plm_embed (int): + The size of PLM embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + embed_dropout (float): + The dropout ratio of input embeddings. Default: .2. + n_encoder_hidden (int): + The size of encoder hidden states. Default: 1200. + n_encoder_layers (int): + The number of encoder layers. Default: 3. + encoder_dropout (float): + The dropout ratio of encoder layer. Default: .33. + n_edge_mlp (int): + Unary factor MLP size. Default: 600. + n_pair_mlp (int): + Binary factor MLP size. Default: 150. + n_label_mlp (int): + Label MLP size. Default: 600. + edge_mlp_dropout (float): + The dropout ratio of unary edge factor MLP layers. Default: .25. + pair_mlp_dropout (float): + The dropout ratio of binary factor MLP layers. Default: .25. + label_mlp_dropout (float): + The dropout ratio of label MLP layers. Default: .33. + inference (str): + Approximate inference methods. Default: ``mfvi``. + max_iter (int): + Max iteration times for inference. Default: 3. + interpolation (int): + Constant to even out the label/edge loss. Default: .1. + pad_index (int): + The index of the padding token in the word vocabulary. Default: 0. + unk_index (int): + The index of the unknown token in the word vocabulary. Default: 1. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__(self, + n_words, + n_labels, + n_tags=None, + n_chars=None, + n_lemmas=None, + encoder='lstm', + feat=['tag', 'char', 'lemma'], + n_embed=100, + n_pretrained=125, + n_feat_embed=100, + n_char_embed=50, + n_char_hidden=100, + char_pad_index=0, + char_dropout=0, + elmo='original_5b', + elmo_bos_eos=(True, False), + bert=None, + n_bert_layers=4, + mix_dropout=.0, + bert_pooling='mean', + bert_pad_index=0, + finetune=False, + n_plm_embed=0, + embed_dropout=.2, + n_encoder_hidden=1200, + n_encoder_layers=3, + encoder_dropout=.33, + n_edge_mlp=600, + n_pair_mlp=150, + n_label_mlp=600, + edge_mlp_dropout=.25, + pair_mlp_dropout=.25, + label_mlp_dropout=.33, + inference='mfvi', + max_iter=3, + interpolation=0.1, + pad_index=0, + unk_index=1, + **kwargs): + super().__init__(**Config().update(locals())) + + self.edge_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.edge_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_edge_mlp, dropout=edge_mlp_dropout, activation=False) + self.pair_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False) + self.pair_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False) + self.pair_mlp_g = MLP(n_in=self.args.n_encoder_hidden, n_out=n_pair_mlp, dropout=pair_mlp_dropout, activation=False) + self.label_mlp_d = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + self.label_mlp_h = MLP(n_in=self.args.n_encoder_hidden, n_out=n_label_mlp, dropout=label_mlp_dropout, activation=False) + + self.edge_attn = Biaffine(n_in=n_edge_mlp, bias_x=True, bias_y=True) + self.sib_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True) + self.cop_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True) + self.grd_attn = Triaffine(n_in=n_pair_mlp, bias_x=True, bias_y=True) + self.label_attn = Biaffine(n_in=n_label_mlp, n_out=n_labels, bias_x=True, bias_y=True) + self.inference = (SemanticDependencyMFVI if inference == 'mfvi' else SemanticDependencyLBP)(max_iter) + self.criterion = nn.CrossEntropyLoss() + + def forward(self, words, feats=None): + r""" + Args: + words (~torch.LongTensor): ``[batch_size, seq_len]``. + Word indices. + feats (List[~torch.LongTensor]): + A list of feat indices. + The size is either ``[batch_size, seq_len, fix_len]`` if ``feat`` is ``'char'`` or ``'bert'``, + or ``[batch_size, seq_len]`` otherwise. + Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~torch.Tensor: + The first and last are scores of all possible edges of shape ``[batch_size, seq_len, seq_len]`` + and possible labels on each edge of shape ``[batch_size, seq_len, seq_len, n_labels]``. + Others are scores of second-order sibling, coparent and grandparent factors + (``[batch_size, seq_len, seq_len, seq_len]``). + + """ + + x = self.encode(words, feats) + + edge_d = self.edge_mlp_d(x) + edge_h = self.edge_mlp_h(x) + pair_d = self.pair_mlp_d(x) + pair_h = self.pair_mlp_h(x) + pair_g = self.pair_mlp_g(x) + label_d = self.label_mlp_d(x) + label_h = self.label_mlp_h(x) + + # [batch_size, seq_len, seq_len] + s_edge = self.edge_attn(edge_d, edge_h) + # [batch_size, seq_len, seq_len, seq_len], (d->h->s) + s_sib = self.sib_attn(pair_d, pair_d, pair_h) + s_sib = (s_sib.triu() + s_sib.triu(1).transpose(-1, -2)).permute(0, 3, 1, 2) + # [batch_size, seq_len, seq_len, seq_len], (d->h->c) + s_cop = self.cop_attn(pair_h, pair_d, pair_h).permute(0, 3, 1, 2) + s_cop = s_cop.triu() + s_cop.triu(1).transpose(-1, -2) + # [batch_size, seq_len, seq_len, seq_len], (d->h->g) + s_grd = self.grd_attn(pair_g, pair_d, pair_h).permute(0, 3, 1, 2) + # [batch_size, seq_len, seq_len, n_labels] + s_label = self.label_attn(label_d, label_h).permute(0, 2, 3, 1) + + return s_edge, s_sib, s_cop, s_grd, s_label + + def loss(self, s_edge, s_sib, s_cop, s_grd, s_label, labels, mask): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible edges. + s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-sibling triples. + s_cop (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-coparent triples. + s_grd (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``. + Scores of all possible dependent-head-grandparent triples. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + The tensor of gold-standard labels. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The training loss and marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + edge_mask = labels.ge(0) & mask + edge_loss, marginals = self.inference((s_edge, s_sib, s_cop, s_grd), mask, edge_mask.long()) + label_loss = self.criterion(s_label[edge_mask], labels[edge_mask]) + loss = self.args.interpolation * label_loss + (1 - self.args.interpolation) * edge_loss + return loss, marginals + + def decode(self, s_edge, s_label): + r""" + Args: + s_edge (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible edges. + s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``. + Scores of all possible labels on each edge. + + Returns: + ~torch.LongTensor: + Predicted labels of shape ``[batch_size, seq_len, seq_len]``. + """ + + return s_label.argmax(-1).masked_fill_(s_edge.lt(0.5), -1) diff --git a/tania_scripts/supar/models/sdp/vi/parser.py b/tania_scripts/supar/models/sdp/vi/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb85c863e992f5f303537867925c816b6d7692c --- /dev/null +++ b/tania_scripts/supar/models/sdp/vi/parser.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +from typing import Iterable, Union + +import torch + +from supar.models.dep.biaffine.transform import CoNLL +from supar.models.sdp.biaffine.parser import BiaffineSemanticDependencyParser +from supar.models.sdp.vi.model import VISemanticDependencyModel +from supar.utils import Config +from supar.utils.logging import get_logger +from supar.utils.metric import ChartMetric +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class VISemanticDependencyParser(BiaffineSemanticDependencyParser): + r""" + The implementation of Semantic Dependency Parser using Variational Inference :cite:`wang-etal-2019-second`. + """ + + NAME = 'vi-semantic-dependency' + MODEL = VISemanticDependencyModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.LEMMA = self.transform.LEMMA + self.TAG = self.transform.POS + self.LABEL = self.transform.PHEAD + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int = 1000, + patience: int = 100, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().train(**Config().update(locals())) + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().evaluate(**Config().update(locals())) + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + return super().predict(**Config().update(locals())) + + def train_step(self, batch: Batch) -> torch.Tensor: + words, *feats, labels = batch + mask = batch.mask + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats) + loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd, s_label, labels, mask) + return loss + + @torch.no_grad() + def eval_step(self, batch: Batch) -> ChartMetric: + words, *feats, labels = batch + mask = batch.mask + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats) + loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd, s_label, labels, mask) + label_preds = self.model.decode(s_edge, s_label) + return ChartMetric(loss, label_preds.masked_fill(~mask, -1), labels.masked_fill(~mask, -1)) + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + words, *feats = batch + mask, lens = batch.mask, (batch.lens - 1).tolist() + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + mask[:, 0] = 0 + s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats) + s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), mask) + label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1) + batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] + for row in chart[1:i, :i].tolist()]) + for i, chart in zip(lens, label_preds)] + if self.args.prob: + batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())] + return batch diff --git a/tania_scripts/supar/modules/__init__.py b/tania_scripts/supar/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..981a6bfbfd0d05203f949840d34d90e2a9097bad --- /dev/null +++ b/tania_scripts/supar/modules/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +from .affine import Biaffine, Triaffine +from .dropout import IndependentDropout, SharedDropout, TokenDropout +from .gnn import GraphConvolutionalNetwork +from .lstm import CharLSTM, VariationalLSTM +from .mlp import MLP +from .decoder import DecoderLSTM +from .pretrained import ELMoEmbedding, TransformerEmbedding +from .transformer import (TransformerDecoder, TransformerEncoder, + TransformerWordEmbedding) + +__all__ = ['Biaffine', 'Triaffine', + 'IndependentDropout', 'SharedDropout', 'TokenDropout', + 'GraphConvolutionalNetwork', + 'CharLSTM', 'VariationalLSTM', + 'MLP', 'DecoderLSTM', + 'ELMoEmbedding', 'TransformerEmbedding', + 'TransformerWordEmbedding', + 'TransformerDecoder', 'TransformerEncoder'] diff --git a/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c13e44da2a22e9ce2413ca58f980bfe9c2110b8 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8dc3d835d3ccc7fab896ae595885b0226ecebc2 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39667f1870be23948663babfaa2080b05451c54f Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/affine.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47fed9d260e0af01681cd36e39620f2209a26b3 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/affine.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a4b2187f9eabbe13f6151c3b2f5c962b577852 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/decoder.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de6a887d7ce612ec0a7ef15992c489fb6b21aa7 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/decoder.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9431bcc12f190fe9a1715caba2b9b9b578ef7e8 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/dropout.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e3f7b05eeb7b0809eb5b72992b12d9e5c160d4 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/dropout.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c9899ee4c9fc3620942fc1768b69b489f133280 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/gnn.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7471454ca9e111edcc54852b2bfeafb07ef79179 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/gnn.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b41ac7e996529a8790ae248cd831e01297c48eb3 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/lstm.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbe023f8403b631f77886bc738ca050d7e7cf659 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/lstm.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..856587590500e87d19034d97b32fba8d05cf1ddb Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/mlp.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c92f7adbd0833ce5ccf154ad0997ff3e2cb672 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/mlp.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe5270597f6a76ea7e0866c0b20e3227024c7199 Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75014d025196b5be3ee840b78c519e0d692f6c4c Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/pretrained.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc b/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa648d7943086a791868bfe31ccd2dac17e4d0e Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/transformer.cpython-310.pyc differ diff --git a/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc b/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb5927d047ca6f9c5b2983417896033e2322681e Binary files /dev/null and b/tania_scripts/supar/modules/__pycache__/transformer.cpython-311.pyc differ diff --git a/tania_scripts/supar/modules/affine.py b/tania_scripts/supar/modules/affine.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5defbfddebaf66a687b0ed15c0a4e557a69683 --- /dev/null +++ b/tania_scripts/supar/modules/affine.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from supar.modules.mlp import MLP + + +class Biaffine(nn.Module): + r""" + Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`. + + This function has a tensor of weights :math:`W` and bias terms if needed. + The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y / d^s`, + where `d` and `s` are vector dimension and scaling factor respectively. + :math:`x` and :math:`y` can be concatenated with bias terms. + + Args: + n_in (int): + The size of the input feature. + n_out (int): + The number of output channels. + n_proj (Optional[int]): + If specified, applies MLP layers to reduce vector dimensions. Default: ``None``. + dropout (Optional[float]): + If specified, applies a :class:`SharedDropout` layer with the ratio on MLP outputs. Default: 0. + scale (float): + Factor to scale the scores. Default: 0. + bias_x (bool): + If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``. + bias_y (bool): + If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``. + decompose (bool): + If ``True``, represents the weight as the product of 2 independent matrices. Default: ``False``. + init (Callable): + Callable initialization method. Default: `nn.init.zeros_`. + """ + + def __init__( + self, + n_in: int, + n_out: int = 1, + n_proj: Optional[int] = None, + dropout: Optional[float] = 0, + scale: int = 0, + bias_x: bool = True, + bias_y: bool = True, + decompose: bool = False, + init: Callable = nn.init.zeros_ + ) -> Biaffine: + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.n_proj = n_proj + self.dropout = dropout + self.scale = scale + self.bias_x = bias_x + self.bias_y = bias_y + self.decompose = decompose + self.init = init + + if n_proj is not None: + self.mlp_x, self.mlp_y = MLP(n_in, n_proj, dropout), MLP(n_in, n_proj, dropout) + self.n_model = n_proj or n_in + if not decompose: + self.weight = nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x, self.n_model + bias_y)) + else: + self.weight = nn.ParameterList((nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x)), + nn.Parameter(torch.Tensor(n_out, self.n_model + bias_y)))) + + self.reset_parameters() + + def __repr__(self): + s = f"n_in={self.n_in}" + if self.n_out > 1: + s += f", n_out={self.n_out}" + if self.n_proj is not None: + s += f", n_proj={self.n_proj}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + if self.scale != 0: + s += f", scale={self.scale}" + if self.bias_x: + s += f", bias_x={self.bias_x}" + if self.bias_y: + s += f", bias_y={self.bias_y}" + if self.decompose: + s += f", decompose={self.decompose}" + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + if self.decompose: + for i in self.weight: + self.init(i) + else: + self.init(self.weight) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor + ) -> torch.Tensor: + r""" + Args: + x (torch.Tensor): ``[batch_size, seq_len, n_in]``. + y (torch.Tensor): ``[batch_size, seq_len, n_in]``. + + Returns: + ~torch.Tensor: + A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``. + If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically. + """ + + if hasattr(self, 'mlp_x'): + x, y = self.mlp_x(x), self.mlp_y(y) + if self.bias_x: + x = torch.cat((x, torch.ones_like(x[..., :1])), -1) + if self.bias_y: + y = torch.cat((y, torch.ones_like(y[..., :1])), -1) + # [batch_size, n_out, seq_len, seq_len] + if self.decompose: + wx = torch.einsum('bxi,oi->box', x, self.weight[0]) + wy = torch.einsum('byj,oj->boy', y, self.weight[1]) + s = torch.einsum('box,boy->boxy', wx, wy) + else: + s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) + return s.squeeze(1) / self.n_in ** self.scale + + +class Triaffine(nn.Module): + r""" + Triaffine layer for second-order scoring :cite:`zhang-etal-2020-efficient,wang-etal-2019-second`. + + This function has a tensor of weights :math:`W` and bias terms if needed. + The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y / d^s`, + where `d` and `s` are vector dimension and scaling factor respectively. + :math:`x` and :math:`y` can be concatenated with bias terms. + + Args: + n_in (int): + The size of the input feature. + n_out (int): + The number of output channels. + n_proj (Optional[int]): + If specified, applies MLP layers to reduce vector dimensions. Default: ``None``. + dropout (Optional[float]): + If specified, applies a :class:`SharedDropout` layer with the ratio on MLP outputs. Default: 0. + scale (float): + Factor to scale the scores. Default: 0. + bias_x (bool): + If ``True``, adds a bias term for tensor :math:`x`. Default: ``False``. + bias_y (bool): + If ``True``, adds a bias term for tensor :math:`y`. Default: ``False``. + decompose (bool): + If ``True``, represents the weight as the product of 3 independent matrices. Default: ``False``. + init (Callable): + Callable initialization method. Default: `nn.init.zeros_`. + """ + + def __init__( + self, + n_in: int, + n_out: int = 1, + n_proj: Optional[int] = None, + dropout: Optional[float] = 0, + scale: int = 0, + bias_x: bool = False, + bias_y: bool = False, + decompose: bool = False, + init: Callable = nn.init.zeros_ + ) -> Triaffine: + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.n_proj = n_proj + self.dropout = dropout + self.scale = scale + self.bias_x = bias_x + self.bias_y = bias_y + self.decompose = decompose + self.init = init + + if n_proj is not None: + self.mlp_x = MLP(n_in, n_proj, dropout) + self.mlp_y = MLP(n_in, n_proj, dropout) + self.mlp_z = MLP(n_in, n_proj, dropout) + self.n_model = n_proj or n_in + if not decompose: + self.weight = nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x, self.n_model, self.n_model + bias_y)) + else: + self.weight = nn.ParameterList((nn.Parameter(torch.Tensor(n_out, self.n_model + bias_x)), + nn.Parameter(torch.Tensor(n_out, self.n_model)), + nn.Parameter(torch.Tensor(n_out, self.n_model + bias_y)))) + + self.reset_parameters() + + def __repr__(self): + s = f"n_in={self.n_in}" + if self.n_out > 1: + s += f", n_out={self.n_out}" + if self.n_proj is not None: + s += f", n_proj={self.n_proj}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + if self.scale != 0: + s += f", scale={self.scale}" + if self.bias_x: + s += f", bias_x={self.bias_x}" + if self.bias_y: + s += f", bias_y={self.bias_y}" + if self.decompose: + s += f", decompose={self.decompose}" + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + if self.decompose: + for i in self.weight: + self.init(i) + else: + self.init(self.weight) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor + ) -> torch.Tensor: + r""" + Args: + x (torch.Tensor): ``[batch_size, seq_len, n_in]``. + y (torch.Tensor): ``[batch_size, seq_len, n_in]``. + z (torch.Tensor): ``[batch_size, seq_len, n_in]``. + + Returns: + ~torch.Tensor: + A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len, seq_len]``. + If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically. + """ + + if hasattr(self, 'mlp_x'): + x, y, z = self.mlp_x(x), self.mlp_y(y), self.mlp_z(y) + if self.bias_x: + x = torch.cat((x, torch.ones_like(x[..., :1])), -1) + if self.bias_y: + y = torch.cat((y, torch.ones_like(y[..., :1])), -1) + # [batch_size, n_out, seq_len, seq_len, seq_len] + if self.decompose: + wx = torch.einsum('bxi,oi->box', x, self.weight[0]) + wz = torch.einsum('bzk,ok->boz', z, self.weight[1]) + wy = torch.einsum('byj,oj->boy', y, self.weight[2]) + s = torch.einsum('box,boz,boy->bozxy', wx, wz, wy) + else: + w = torch.einsum('bzk,oikj->bozij', z, self.weight) + s = torch.einsum('bxi,bozij,byj->bozxy', x, w, y) + return s.squeeze(1) / self.n_in ** self.scale diff --git a/tania_scripts/supar/modules/decoder.py b/tania_scripts/supar/modules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9f97f310628c589ab16520c288c1b02f11c050 --- /dev/null +++ b/tania_scripts/supar/modules/decoder.py @@ -0,0 +1,38 @@ +import torch.nn as nn +from supar.modules.mlp import MLP +import torch + +class DecoderLSTM(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, dropout: float, device: str): + super().__init__() + + self.input_size, self.hidden_size = input_size, hidden_size + self.output_size = output_size + self.num_layers = num_layers + + self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, bidirectional=False, + dropout=dropout, batch_first=True) + + self.mlp = MLP(n_in=hidden_size, n_out=output_size, dropout=dropout, + activation=True) + self.device = device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + :param x: torch.Tensor [batch_size, seq_len, input_size] + :returns torch.Tensor [batch_size, seq_len, output_size] + """ + batch_size, seq_len, _ = x.shape + + # LSTM forward pass + h0, c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device), \ + torch.zeros(self.num_layers, batch_size, self.hidden_size).to(self.device) + hn, cn = self.lstm(x, (h0, c0)) + + # MLP forward pass + output = self.mlp(hn.reshape(batch_size*seq_len, self.hidden_size)) + return output.reshape(batch_size, seq_len, self.output_size) + + def __repr__(self): + return f'DecoderLSTM(input_size={self.input_size}, hidden_size={self.hidden_size}, num_layers={self.num_layers}, output_size={self.output_size}' diff --git a/tania_scripts/supar/modules/dropout.py b/tania_scripts/supar/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3c94705e5b02bb6e1a3df1fc8428377f40684a --- /dev/null +++ b/tania_scripts/supar/modules/dropout.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List + +import torch +import torch.nn as nn + + +class TokenDropout(nn.Module): + r""" + :class:`TokenDropout` seeks to randomly zero the vectors of some tokens with the probability of `p`. + + Args: + p (float): + The probability of an element to be zeroed. Default: 0.5. + + Examples: + >>> batch_size, seq_len, hidden_size = 1, 3, 5 + >>> x = torch.ones(batch_size, seq_len, hidden_size) + >>> nn.Dropout()(x) + tensor([[[0., 2., 2., 0., 0.], + [2., 2., 0., 2., 2.], + [2., 2., 2., 2., 0.]]]) + >>> TokenDropout()(x) + tensor([[[2., 2., 2., 2., 2.], + [0., 0., 0., 0., 0.], + [2., 2., 2., 2., 2.]]]) + """ + + def __init__(self, p: float = 0.5) -> TokenDropout: + super().__init__() + + self.p = p + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p})" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): + A tensor of any shape. + Returns: + A tensor with the same shape as `x`. + """ + + if not self.training: + return x + return x * (x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) / (1 - self.p)).unsqueeze(-1) + + +class SharedDropout(nn.Module): + r""" + :class:`SharedDropout` differs from the vanilla dropout strategy in that the dropout mask is shared across one dimension. + + Args: + p (float): + The probability of an element to be zeroed. Default: 0.5. + batch_first (bool): + If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``. + Default: ``True``. + + Examples: + >>> batch_size, seq_len, hidden_size = 1, 3, 5 + >>> x = torch.ones(batch_size, seq_len, hidden_size) + >>> nn.Dropout()(x) + tensor([[[0., 2., 2., 0., 0.], + [2., 2., 0., 2., 2.], + [2., 2., 2., 2., 0.]]]) + >>> SharedDropout()(x) + tensor([[[2., 0., 2., 0., 2.], + [2., 0., 2., 0., 2.], + [2., 0., 2., 0., 2.]]]) + """ + + def __init__(self, p: float = 0.5, batch_first: bool = True) -> SharedDropout: + super().__init__() + + self.p = p + self.batch_first = batch_first + + def __repr__(self): + s = f"p={self.p}" + if self.batch_first: + s += f", batch_first={self.batch_first}" + return f"{self.__class__.__name__}({s})" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): + A tensor of any shape. + Returns: + A tensor with the same shape as `x`. + """ + + if not self.training: + return x + return x * self.get_mask(x[:, 0], self.p).unsqueeze(1) if self.batch_first else self.get_mask(x[0], self.p) + + @staticmethod + def get_mask(x: torch.Tensor, p: float) -> torch.FloatTensor: + return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p) + + +class IndependentDropout(nn.Module): + r""" + For :math:`N` tensors, they use different dropout masks respectively. + When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate, + and when all of them are dropped together, zeros are returned. + + Args: + p (float): + The probability of an element to be zeroed. Default: 0.5. + + Examples: + >>> batch_size, seq_len, hidden_size = 1, 3, 5 + >>> x, y = torch.ones(batch_size, seq_len, hidden_size), torch.ones(batch_size, seq_len, hidden_size) + >>> x, y = IndependentDropout()(x, y) + >>> x + tensor([[[1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.], + [2., 2., 2., 2., 2.]]]) + >>> y + tensor([[[1., 1., 1., 1., 1.], + [2., 2., 2., 2., 2.], + [0., 0., 0., 0., 0.]]]) + """ + + def __init__(self, p: float = 0.5) -> IndependentDropout: + super().__init__() + + self.p = p + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p})" + + def forward(self, *items: List[torch.Tensor]) -> List[torch.Tensor]: + r""" + Args: + items (List[~torch.Tensor]): + A list of tensors that have the same shape except the last dimension. + Returns: + A tensors are of the same shape as `items`. + """ + + if not self.training: + return items + masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items] + total = sum(masks) + scale = len(items) / total.max(torch.ones_like(total)) + masks = [mask * scale for mask in masks] + return [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)] diff --git a/tania_scripts/supar/modules/gnn.py b/tania_scripts/supar/modules/gnn.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2108c20a02ed4a1eb18d468ab2084d365f2b43 --- /dev/null +++ b/tania_scripts/supar/modules/gnn.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import torch +import torch.nn as nn + + +class GraphConvolutionalNetwork(nn.Module): + r""" + Multiple GCN layers with layer normalization and residual connections, each executing the operator + from the `"Semi-supervised Classification with Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper + + .. math:: + \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with inserted self-loops + and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + + Its node-wise formulation is given by: + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in + \mathcal{N}(v) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j + \hat{d}_i}} \mathbf{x}_j + + with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where + :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target + node :obj:`i` (default: :obj:`1.0`) + + Args: + n_model (int): + The size of node feature vectors. + n_layers (int): + The number of GCN layers. Default: 1. + selfloop (bool): + If ``True``, adds self-loops to adjacent matrices. Default: ``True``. + dropout (float): + The probability of feature vector elements to be zeroed. Default: 0. + norm (bool): + If ``True``, adds a :class:`~torch.nn.LayerNorm` layer after each GCN layer. Default: ``True``. + """ + + def __init__( + self, + n_model: int, + n_layers: int = 1, + selfloop: bool = True, + dropout: float = 0., + norm: bool = True + ) -> GraphConvolutionalNetwork: + super().__init__() + + self.n_model = n_model + self.n_layers = n_layers + self.selfloop = selfloop + self.norm = norm + + self.conv_layers = nn.ModuleList([ + nn.Sequential( + GraphConv(n_model), + nn.LayerNorm([n_model]) if norm else nn.Identity() + ) + for _ in range(n_layers) + ]) + self.dropout = nn.Dropout(dropout) + + def __repr__(self): + s = f"n_model={self.n_model}, n_layers={self.n_layers}" + if self.selfloop: + s += f", selfloop={self.selfloop}" + if self.dropout.p > 0: + s += f", dropout={self.dropout.p}" + if self.norm: + s += f", norm={self.norm}" + return f"{self.__class__.__name__}({s})" + + def forward(self, x: torch.Tensor, adj: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): + Node feature tensors of shape ``[batch_size, seq_len, n_model]``. + adj (~torch.Tensor): + Adjacent matrix of shape ``[batch_size, seq_len, seq_len]``. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask for covering the unpadded tokens in each chart. + + Returns: + ~torch.Tensor: + Node feature tensors of shape ``[batch_size, seq_len, n_model]``. + """ + + if self.selfloop: + adj.diagonal(0, 1, 2).fill_(1.) + adj = adj.masked_fill(~(mask.unsqueeze(1) & mask.unsqueeze(2)), 0) + for conv, norm in self.conv_layers: + x = norm(x + self.dropout(conv(x, adj).relu())) + return x + + +class GraphConv(nn.Module): + + def __init__(self, n_model: int, bias: bool = True) -> GraphConv: + super().__init__() + + self.n_model = n_model + + self.linear = nn.Linear(n_model, n_model, bias=False) + self.bias = nn.Parameter(torch.zeros(n_model)) if bias else None + + def __repr__(self): + s = f"n_model={self.n_model}" + if self.bias is not None: + s += ", bias=True" + return f"{self.__class__.__name__}({s})" + + def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): + Node feature tensors of shape ``[batch_size, seq_len, n_model]``. + adj (~torch.Tensor): + Adjacent matrix of shape ``[batch_size, seq_len, seq_len]``. + + Returns: + ~torch.Tensor: + Node feature tensors of shape ``[batch_size, seq_len, n_model]``. + """ + + x = self.linear(x) + x = torch.matmul(adj * (adj.sum(1, True) * adj.sum(2, True) + torch.finfo(adj.dtype).eps).pow(-0.5), x) + if self.bias is not None: + x = x + self.bias + return x diff --git a/tania_scripts/supar/modules/lstm.py b/tania_scripts/supar/modules/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..759f9c187f28ff6c8b662a6c6feee751198b83db --- /dev/null +++ b/tania_scripts/supar/modules/lstm.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from supar.modules.dropout import SharedDropout +from torch.nn.modules.rnn import apply_permutation +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence + + +class CharLSTM(nn.Module): + r""" + CharLSTM aims to generate character-level embeddings for tokens. + It summarizes the information of characters in each token to an embedding using a LSTM layer. + + Args: + n_char (int): + The number of characters. + n_embed (int): + The size of each embedding vector as input to LSTM. + n_hidden (int): + The size of each LSTM hidden state. + n_out (int): + The size of each output vector. Default: 0. + If 0, equals to the size of hidden states. + pad_index (int): + The index of the padding token in the vocabulary. Default: 0. + dropout (float): + The dropout ratio of CharLSTM hidden states. Default: 0. + """ + + def __init__( + self, + n_chars: int, + n_embed: int, + n_hidden: int, + n_out: int = 0, + pad_index: int = 0, + dropout: float = 0 + ) -> CharLSTM: + super().__init__() + + self.n_chars = n_chars + self.n_embed = n_embed + self.n_hidden = n_hidden + self.n_out = n_out or n_hidden + self.pad_index = pad_index + + self.embed = nn.Embedding(num_embeddings=n_chars, embedding_dim=n_embed) + self.lstm = nn.LSTM(input_size=n_embed, hidden_size=n_hidden//2, batch_first=True, bidirectional=True) + self.dropout = nn.Dropout(p=dropout) + self.projection = nn.Linear(in_features=n_hidden, out_features=self.n_out) if n_hidden != self.n_out else nn.Identity() + + def __repr__(self): + s = f"{self.n_chars}, {self.n_embed}" + if self.n_hidden != self.n_out: + s += f", n_hidden={self.n_hidden}" + s += f", n_out={self.n_out}, pad_index={self.pad_index}" + if self.dropout.p != 0: + s += f", dropout={self.dropout.p}" + return f"{self.__class__.__name__}({s})" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. + Characters of all tokens. + Each token holds no more than `fix_len` characters, and the excess is cut off directly. + Returns: + ~torch.Tensor: + The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters. + """ + + # [batch_size, seq_len, fix_len] + mask = x.ne(self.pad_index) + # [batch_size, seq_len] + lens = mask.sum(-1) + char_mask = lens.gt(0) + + # [n, fix_len, n_embed] + x = self.embed(x[char_mask]) + x = pack_padded_sequence(x, lens[char_mask].tolist(), True, False) + x, (h, _) = self.lstm(x) + # [n, fix_len, n_hidden] + h = self.dropout(torch.cat(torch.unbind(h), -1)) + # [n, fix_len, n_out] + h = self.projection(h) + # [batch_size, seq_len, n_out] + return h.new_zeros(*lens.shape, self.n_out).masked_scatter_(char_mask.unsqueeze(-1), h) + + +class VariationalLSTM(nn.Module): + r""" + VariationalLSTM :cite:`yarin-etal-2016-dropout` is an variant of the vanilla bidirectional LSTM + adopted by Biaffine Parser with the only difference of the dropout strategy. + It drops nodes in the LSTM layers (input and recurrent connections) + and applies the same dropout mask at every recurrent timesteps. + + APIs are roughly the same as :class:`~torch.nn.LSTM` except that we only allows + :class:`~torch.nn.utils.rnn.PackedSequence` as input. + + Args: + input_size (int): + The number of expected features in the input. + hidden_size (int): + The number of features in the hidden state `h`. + num_layers (int): + The number of recurrent layers. Default: 1. + bidirectional (bool): + If ``True``, becomes a bidirectional LSTM. Default: ``False`` + dropout (float): + If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer except the last layer. + Default: 0. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bidirectional: bool = False, + dropout: float = .0 + ) -> VariationalLSTM: + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bidirectional = bidirectional + self.dropout = dropout + self.num_directions = 1 + self.bidirectional + + self.f_cells = nn.ModuleList() + if bidirectional: + self.b_cells = nn.ModuleList() + for _ in range(self.num_layers): + self.f_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)) + if bidirectional: + self.b_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)) + input_size = hidden_size * self.num_directions + + self.reset_parameters() + + def __repr__(self): + s = f"{self.input_size}, {self.hidden_size}" + if self.num_layers > 1: + s += f", num_layers={self.num_layers}" + if self.bidirectional: + s += f", bidirectional={self.bidirectional}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + for param in self.parameters(): + # apply orthogonal_ to weight + if len(param.shape) > 1: + nn.init.orthogonal_(param) + # apply zeros_ to bias + else: + nn.init.zeros_(param) + + def permute_hidden( + self, + hx: Tuple[torch.Tensor, torch.Tensor], + permutation: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if permutation is None: + return hx + h = apply_permutation(hx[0], permutation) + c = apply_permutation(hx[1], permutation) + + return h, c + + def layer_forward( + self, + x: List[torch.Tensor], + hx: Tuple[torch.Tensor, torch.Tensor], + cell: nn.LSTMCell, + batch_sizes: List[int], + reverse: bool = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + hx_0 = hx_i = hx + hx_n, output = [], [] + steps = reversed(range(len(x))) if reverse else range(len(x)) + if self.training: + hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout) + + for t in steps: + last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t] + if last_batch_size < batch_size: + hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)] + else: + hx_n.append([h[batch_size:] for h in hx_i]) + hx_i = [h[:batch_size] for h in hx_i] + hx_i = [h for h in cell(x[t], hx_i)] + output.append(hx_i[0]) + if self.training: + hx_i[0] = hx_i[0] * hid_mask[:batch_size] + if reverse: + hx_n = hx_i + output.reverse() + else: + hx_n.append(hx_i) + hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))] + output = torch.cat(output) + + return output, hx_n + + def forward( + self, + sequence: PackedSequence, + hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: + r""" + Args: + sequence (~torch.nn.utils.rnn.PackedSequence): + A packed variable length sequence. + hx (~torch.Tensor, ~torch.Tensor): + A tuple composed of two tensors `h` and `c`. + `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial hidden state + for each element in the batch. + `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial cell state + for each element in the batch. + If `hx` is not provided, both `h` and `c` default to zero. + Default: ``None``. + + Returns: + ~torch.nn.utils.rnn.PackedSequence, (~torch.Tensor, ~torch.Tensor): + The first is a packed variable length sequence. + The second is a tuple of tensors `h` and `c`. + `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the hidden state for `t=seq_len`. + Like output, the layers can be separated using ``h.view(num_layers, num_directions, batch_size, hidden_size)`` + and similarly for c. + `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the cell state for `t=seq_len`. + """ + x, batch_sizes = sequence.data, sequence.batch_sizes.tolist() + batch_size = batch_sizes[0] + h_n, c_n = [], [] + + if hx is None: + ih = x.new_zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size) + h, c = ih, ih + else: + h, c = self.permute_hidden(hx, sequence.sorted_indices) + h = h.view(self.num_layers, self.num_directions, batch_size, self.hidden_size) + c = c.view(self.num_layers, self.num_directions, batch_size, self.hidden_size) + + for i in range(self.num_layers): + x = torch.split(x, batch_sizes) + if self.training: + mask = SharedDropout.get_mask(x[0], self.dropout) + x = [i * mask[:len(i)] for i in x] + x_i, (h_i, c_i) = self.layer_forward(x, (h[i, 0], c[i, 0]), self.f_cells[i], batch_sizes) + if self.bidirectional: + x_b, (h_b, c_b) = self.layer_forward(x, (h[i, 1], c[i, 1]), self.b_cells[i], batch_sizes, True) + x_i = torch.cat((x_i, x_b), -1) + h_i = torch.stack((h_i, h_b)) + c_i = torch.stack((c_i, c_b)) + x = x_i + h_n.append(h_i) + c_n.append(c_i) + + x = PackedSequence(x, sequence.batch_sizes, sequence.sorted_indices, sequence.unsorted_indices) + hx = torch.cat(h_n, 0), torch.cat(c_n, 0) + hx = self.permute_hidden(hx, sequence.unsorted_indices) + + return x, hx diff --git a/tania_scripts/supar/modules/mlp.py b/tania_scripts/supar/modules/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..f26cdd983b70e80110a81253387f30c3c1c79ead --- /dev/null +++ b/tania_scripts/supar/modules/mlp.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import torch +import torch.nn as nn +from supar.modules.dropout import SharedDropout + + +class MLP(nn.Module): + r""" + Applies a linear transformation together with a non-linear activation to the incoming tensor: + :math:`y = \mathrm{Activation}(x A^T + b)` + + Args: + n_in (~torch.Tensor): + The size of each input feature. + n_out (~torch.Tensor): + The size of each output feature. + dropout (float): + If non-zero, introduces a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0. + activation (bool): + Whether to use activations. Default: True. + """ + + def __init__(self, n_in: int, n_out: int, dropout: float = .0, activation: bool = True) -> MLP: + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.linear = nn.Linear(n_in, n_out) + self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity() + self.dropout = SharedDropout(p=dropout) + + self.reset_parameters() + + def __repr__(self): + s = f"n_in={self.n_in}, n_out={self.n_out}" + if self.dropout.p > 0: + s += f", dropout={self.dropout.p}" + + return f"{self.__class__.__name__}({s})" + + def reset_parameters(self): + nn.init.orthogonal_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + x (~torch.Tensor): + The size of each input feature is `n_in`. + + Returns: + A tensor with the size of each output feature `n_out`. + """ + + x = self.linear(x) + x = self.activation(x) + x = self.dropout(x) + + return x diff --git a/tania_scripts/supar/modules/pretrained.py b/tania_scripts/supar/modules/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..4805c88474a42bd3fde1cab6c9fc12f0b8b03558 --- /dev/null +++ b/tania_scripts/supar/modules/pretrained.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List, Tuple + +import torch +import torch.nn as nn +from supar.utils.fn import pad +from supar.utils.tokenizer import TransformerTokenizer + + +class TransformerEmbedding(nn.Module): + r""" + Bidirectional transformer embeddings of words from various transformer architectures :cite:`devlin-etal-2019-bert`. + + Args: + name (str): + Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``. + n_layers (int): + The number of BERT layers to use. If 0, uses all layers. + n_out (int): + The requested size of the embeddings. If 0, uses the size of the pretrained embedding model. Default: 0. + stride (int): + A sequence longer than max length will be splitted into several small pieces + with a window size of ``stride``. Default: 10. + pooling (str): + Pooling way to get from token piece embeddings to token embedding. + ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. + Default: ``mean``. + pad_index (int): + The index of the padding token in BERT vocabulary. Default: 0. + mix_dropout (float): + The dropout ratio of BERT layers. This value will be passed into the :class:`ScalarMix` layer. Default: 0. + finetune (bool): + If ``True``, the model parameters will be updated together with the downstream task. Default: ``False``. + + .. _transformers: + https://github.com/huggingface/transformers + """ + + def __init__( + self, + name: str, + n_layers: int, + n_out: int = 0, + stride: int = 256, + pooling: str = 'mean', + pad_index: int = 0, + mix_dropout: float = .0, + finetune: bool = False + ) -> TransformerEmbedding: + super().__init__() + + from transformers import AutoModel + try: + self.model = AutoModel.from_pretrained(name, output_hidden_states=True, local_files_only=True) + except Exception: + self.model = AutoModel.from_pretrained(name, output_hidden_states=True, local_files_only=False) + self.model = self.model.requires_grad_(finetune) + self.tokenizer = TransformerTokenizer(name) + + self.name = name + self.n_layers = n_layers or self.model.config.num_hidden_layers + self.hidden_size = self.model.config.hidden_size + self.n_out = n_out or self.hidden_size + self.pooling = pooling + self.pad_index = pad_index + self.mix_dropout = mix_dropout + self.finetune = finetune + try: + self.max_len = int(max(0, self.model.config.max_position_embeddings) or 1e12) - 2 + except: + self.max_len = 512 + self.stride = min(stride, self.max_len) + + self.scalar_mix = ScalarMix(self.n_layers, mix_dropout) + self.projection = nn.Linear(self.hidden_size, self.n_out, False) if self.hidden_size != n_out else nn.Identity() + + def __repr__(self): + s = f"{self.name}, n_layers={self.n_layers}, n_out={self.n_out}, " + s += f"stride={self.stride}, pooling={self.pooling}, pad_index={self.pad_index}" + if self.mix_dropout > 0: + s += f", mix_dropout={self.mix_dropout}" + if self.finetune: + s += f", finetune={self.finetune}" + return f"{self.__class__.__name__}({s})" + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + r""" + Args: + tokens (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. + + Returns: + ~torch.Tensor: + Contextualized token embeddings of shape ``[batch_size, seq_len, n_out]``. + """ + + mask = tokens.ne(self.pad_index) + lens = mask.sum((1, 2)) + # [batch_size, n_subwords] + tokens = pad(tokens[mask].split(lens.tolist()), self.pad_index, padding_side=self.tokenizer.padding_side) + token_mask = pad(mask[mask].split(lens.tolist()), 0, padding_side=self.tokenizer.padding_side) + + # return the hidden states of all layers + x = self.model(tokens[:, :self.max_len], attention_mask=token_mask[:, :self.max_len].float())[-1] + # [batch_size, max_len, hidden_size] + x = self.scalar_mix(x[-self.n_layers:]) + # [batch_size, n_subwords, hidden_size] + for i in range(self.stride, (tokens.shape[1]-self.max_len+self.stride-1)//self.stride*self.stride+1, self.stride): + part = self.model(tokens[:, i:i+self.max_len], attention_mask=token_mask[:, i:i+self.max_len].float())[-1] + x = torch.cat((x, self.scalar_mix(part[-self.n_layers:])[:, self.max_len-self.stride:]), 1) + # [batch_size, seq_len] + lens = mask.sum(-1) + lens = lens.masked_fill_(lens.eq(0), 1) + # [batch_size, seq_len, fix_len, hidden_size] + x = x.new_zeros(*mask.shape, self.hidden_size).masked_scatter_(mask.unsqueeze(-1), x[token_mask]) + # [batch_size, seq_len, hidden_size] + if self.pooling == 'first': + x = x[:, :, 0] + elif self.pooling == 'last': + x = x.gather(2, (lens-1).unsqueeze(-1).repeat(1, 1, self.hidden_size).unsqueeze(2)).squeeze(2) + elif self.pooling == 'mean': + x = x.sum(2) / lens.unsqueeze(-1) + else: + raise RuntimeError(f'Unsupported pooling method "{self.pooling}"!') + return self.projection(x) + + +class ELMoEmbedding(nn.Module): + r""" + Contextual word embeddings using word-level bidirectional LM :cite:`peters-etal-2018-deep`. + + Args: + name (str): + The name of the pretrained ELMo registered in `OPTION` and `WEIGHT`. Default: ``'original_5b'``. + bos_eos (Tuple[bool]): + A tuple of two boolean values indicating whether to keep start/end boundaries of sentence outputs. + Default: ``(True, True)``. + n_out (int): + The requested size of the embeddings. If 0, uses the default size of ELMo outputs. Default: 0. + dropout (float): + The dropout ratio for the ELMo layer. Default: 0. + finetune (bool): + If ``True``, the model parameters will be updated together with the downstream task. Default: ``False``. + """ + + OPTION = { + 'small': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json', # noqa + 'medium': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json', # noqa + 'original': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json', # noqa + 'original_5b': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json', # noqa + } + WEIGHT = { + 'small': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5', # noqa + 'medium': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5', # noqa + 'original': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', # noqa + 'original_5b': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5', # noqa + } + + def __init__( + self, + name: str = 'original_5b', + bos_eos: Tuple[bool, bool] = (True, True), + n_out: int = 0, + dropout: float = 0.5, + finetune: bool = False + ) -> ELMoEmbedding: + super().__init__() + + from allennlp.modules import Elmo + + self.elmo = Elmo(options_file=self.OPTION[name], + weight_file=self.WEIGHT[name], + num_output_representations=1, + dropout=dropout, + finetune=finetune, + keep_sentence_boundaries=True) + + self.name = name + self.bos_eos = bos_eos + self.hidden_size = self.elmo.get_output_dim() + self.n_out = n_out or self.hidden_size + self.dropout = dropout + self.finetune = finetune + + self.projection = nn.Linear(self.hidden_size, self.n_out, False) if self.hidden_size != n_out else nn.Identity() + + def __repr__(self): + s = f"{self.name}, n_out={self.n_out}" + if self.dropout > 0: + s += f", dropout={self.dropout}" + if self.finetune: + s += f", finetune={self.finetune}" + return f"{self.__class__.__name__}({s})" + + def forward(self, chars: torch.LongTensor) -> torch.Tensor: + r""" + Args: + chars (~torch.LongTensor): ``[batch_size, seq_len, fix_len]``. + + Returns: + ~torch.Tensor: + ELMo embeddings of shape ``[batch_size, seq_len, n_out]``. + """ + + x = self.projection(self.elmo(chars)['elmo_representations'][0]) + if not self.bos_eos[0]: + x = x[:, 1:] + if not self.bos_eos[1]: + x = x[:, :-1] + return x + + +class ScalarMix(nn.Module): + r""" + Computes a parameterized scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)` + where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters. + + Args: + n_layers (int): + The number of layers to be mixed, i.e., :math:`N`. + dropout (float): + The dropout ratio of the layer weights. + If dropout > 0, then for each scalar weight, adjusts its softmax weight mass to 0 + with the dropout probability (i.e., setting the unnormalized weight to -inf). + This effectively redistributes the dropped probability mass to all other weights. + Default: 0. + """ + + def __init__(self, n_layers: int, dropout: float = .0) -> ScalarMix: + super().__init__() + + self.n_layers = n_layers + + self.weights = nn.Parameter(torch.zeros(n_layers)) + self.gamma = nn.Parameter(torch.tensor([1.0])) + self.dropout = nn.Dropout(dropout) + + def __repr__(self): + s = f"n_layers={self.n_layers}" + if self.dropout.p > 0: + s += f", dropout={self.dropout.p}" + return f"{self.__class__.__name__}({s})" + + def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: + r""" + Args: + tensors (List[~torch.Tensor]): + :math:`N` tensors to be mixed. + + Returns: + The mixture of :math:`N` tensors. + """ + + return self.gamma * sum(w * h for w, h in zip(self.dropout(self.weights.softmax(-1)), tensors)) diff --git a/tania_scripts/supar/modules/transformer.py b/tania_scripts/supar/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7c967fee3e99bdbe4a81f8542267a93465dca7 --- /dev/null +++ b/tania_scripts/supar/modules/transformer.py @@ -0,0 +1,585 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import copy +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TransformerWordEmbedding(nn.Module): + + def __init__( + self, + n_vocab: int = None, + n_embed: int = None, + embed_scale: Optional[int] = None, + max_len: Optional[int] = 512, + pos: Optional[str] = None, + pad_index: Optional[int] = None, + ) -> TransformerWordEmbedding: + super(TransformerWordEmbedding, self).__init__() + + self.embed = nn.Embedding(num_embeddings=n_vocab, + embedding_dim=n_embed) + if pos is None: + self.pos_embed = nn.Identity() + elif pos == 'sinusoid': + self.pos_embed = SinusoidPositionalEmbedding() + elif pos == 'sinusoid_relative': + self.pos_embed = SinusoidRelativePositionalEmbedding() + elif pos == 'learnable': + self.pos_embed = PositionalEmbedding(max_len=max_len) + elif pos == 'learnable_relative': + self.pos_embed = RelativePositionalEmbedding(max_len=max_len) + else: + raise ValueError(f'Unknown positional embedding type {pos}') + + self.n_vocab = n_vocab + self.n_embed = n_embed + self.embed_scale = embed_scale or n_embed ** 0.5 + self.max_len = max_len + self.pos = pos + self.pad_index = pad_index + + self.reset_parameters() + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += f"{self.n_vocab}, {self.n_embed}" + if self.embed_scale is not None: + s += f", embed_scale={self.embed_scale:.2f}" + if self.max_len is not None: + s += f", max_len={self.max_len}" + if self.pos is not None: + s += f", pos={self.pos}" + if self.pad_index is not None: + s += f", pad_index={self.pad_index}" + s += ')' + return s + + def reset_parameters(self): + nn.init.normal_(self.embed.weight, 0, self.n_embed ** -0.5) + if self.pad_index is not None: + nn.init.zeros_(self.embed.weight[self.pad_index]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.embed(x) + if self.embed_scale: + x = x * self.embed_scale + if self.pos is not None: + x = x + self.pos_embed(x) + return x + + +class TransformerEncoder(nn.Module): + + def __init__( + self, + layer: nn.Module, + n_layers: int = 6, + n_model: int = 1024, + pre_norm: bool = False, + ) -> TransformerEncoder: + super(TransformerEncoder, self).__init__() + + self.n_layers = n_layers + self.n_model = n_model + self.pre_norm = pre_norm + + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)]) + self.norm = nn.LayerNorm(n_model) if self.pre_norm else None + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + x = x.transpose(0, 1) + for layer in self.layers: + x = layer(x, mask) + if self.pre_norm: + x = self.norm(x) + return x.transpose(0, 1) + + +class TransformerDecoder(nn.Module): + + def __init__( + self, + layer: nn.Module, + n_layers: int = 6, + n_model: int = 1024, + pre_norm: bool = False, + ) -> TransformerDecoder: + super(TransformerDecoder, self).__init__() + + self.n_layers = n_layers + self.n_model = n_model + self.pre_norm = pre_norm + + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)]) + self.norm = nn.LayerNorm(n_model) if self.pre_norm else None + + def forward( + self, + x_tgt: torch.Tensor, + x_src: torch.Tensor, + tgt_mask: torch.BoolTensor, + src_mask: torch.BoolTensor, + attn_mask: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + x_tgt, x_src = x_tgt.transpose(0, 1), x_src.transpose(0, 1) + for layer in self.layers: + x_tgt = layer(x_tgt=x_tgt, + x_src=x_src, + tgt_mask=tgt_mask, + src_mask=src_mask, + attn_mask=attn_mask) + if self.pre_norm: + x_tgt = self.norm(x_tgt) + return x_tgt.transpose(0, 1) + + +class TransformerEncoderLayer(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_inner: int = 2048, + activation: str = 'relu', + bias: bool = True, + pre_norm: bool = False, + attn_dropout: float = 0.1, + ffn_dropout: float = 0.1, + dropout: float = 0.1 + ) -> TransformerEncoderLayer: + super(TransformerEncoderLayer, self).__init__() + + self.attn = MultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout, + bias=bias) + self.attn_norm = nn.LayerNorm(n_model) + self.ffn = PositionwiseFeedForward(n_model=n_model, + n_inner=n_inner, + activation=activation, + dropout=ffn_dropout) + self.ffn_norm = nn.LayerNorm(n_model) + self.dropout = nn.Dropout(dropout) + + self.pre_norm = pre_norm + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + if self.pre_norm: + n = self.attn_norm(x) + x = x + self.dropout(self.attn(n, n, n, mask)) + n = self.ffn_norm(x) + x = x + self.dropout(self.ffn(n)) + else: + x = self.attn_norm(x + self.dropout(self.attn(x, x, x, mask))) + x = self.ffn_norm(x + self.dropout(self.ffn(x))) + return x + + +class RelativePositionTransformerEncoderLayer(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_inner: int = 2048, + activation: str = 'relu', + pre_norm: bool = False, + attn_dropout: float = 0.1, + ffn_dropout: float = 0.1, + dropout: float = 0.1 + ) -> RelativePositionTransformerEncoderLayer: + super(RelativePositionTransformerEncoderLayer, self).__init__() + + self.attn = RelativePositionMultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout) + self.attn_norm = nn.LayerNorm(n_model) + self.ffn = PositionwiseFeedForward(n_model=n_model, + n_inner=n_inner, + activation=activation, + dropout=ffn_dropout) + self.ffn_norm = nn.LayerNorm(n_model) + self.dropout = nn.Dropout(dropout) + + self.pre_norm = pre_norm + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + if self.pre_norm: + n = self.attn_norm(x) + x = x + self.dropout(self.attn(n, n, n, mask)) + n = self.ffn_norm(x) + x = x + self.dropout(self.ffn(n)) + else: + x = self.attn_norm(x + self.dropout(self.attn(x, x, x, mask))) + x = self.ffn_norm(x + self.dropout(self.ffn(x))) + return x + + +class TransformerDecoderLayer(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_inner: int = 2048, + activation: str = 'relu', + bias: bool = True, + pre_norm: bool = False, + attn_dropout: float = 0.1, + ffn_dropout: float = 0.1, + dropout: float = 0.1 + ) -> TransformerDecoderLayer: + super(TransformerDecoderLayer, self).__init__() + + self.self_attn = MultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout, + bias=bias) + self.self_attn_norm = nn.LayerNorm(n_model) + self.mha_attn = MultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout, + bias=bias) + self.mha_attn_norm = nn.LayerNorm(n_model) + self.ffn = PositionwiseFeedForward(n_model=n_model, + n_inner=n_inner, + activation=activation, + dropout=ffn_dropout) + self.ffn_norm = nn.LayerNorm(n_model) + self.dropout = nn.Dropout(dropout) + + self.pre_norm = pre_norm + + def forward( + self, + x_tgt: torch.Tensor, + x_src: torch.Tensor, + tgt_mask: torch.BoolTensor, + src_mask: torch.BoolTensor, + attn_mask: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + if self.pre_norm: + n_tgt = self.self_attn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.self_attn(n_tgt, n_tgt, n_tgt, tgt_mask, attn_mask)) + n_tgt = self.mha_attn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.mha_attn(n_tgt, x_src, x_src, src_mask)) + n_tgt = self.ffn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.ffn(x_tgt)) + else: + x_tgt = self.self_attn_norm(x_tgt + self.dropout(self.self_attn(x_tgt, x_tgt, x_tgt, tgt_mask, attn_mask))) + x_tgt = self.mha_attn_norm(x_tgt + self.dropout(self.mha_attn(x_tgt, x_src, x_src, src_mask))) + x_tgt = self.ffn_norm(x_tgt + self.dropout(self.ffn(x_tgt))) + return x_tgt + + +class RelativePositionTransformerDecoderLayer(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_inner: int = 2048, + activation: str = 'relu', + pre_norm: bool = False, + attn_dropout: float = 0.1, + ffn_dropout: float = 0.1, + dropout: float = 0.1 + ) -> RelativePositionTransformerDecoderLayer: + super(RelativePositionTransformerDecoderLayer, self).__init__() + + self.self_attn = RelativePositionMultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout) + self.self_attn_norm = nn.LayerNorm(n_model) + self.mha_attn = RelativePositionMultiHeadAttention(n_heads=n_heads, + n_model=n_model, + n_embed=n_model//n_heads, + dropout=attn_dropout) + self.mha_attn_norm = nn.LayerNorm(n_model) + self.ffn = PositionwiseFeedForward(n_model=n_model, + n_inner=n_inner, + activation=activation, + dropout=ffn_dropout) + self.ffn_norm = nn.LayerNorm(n_model) + self.dropout = nn.Dropout(dropout) + + self.pre_norm = pre_norm + + def forward( + self, + x_tgt: torch.Tensor, + x_src: torch.Tensor, + tgt_mask: torch.BoolTensor, + src_mask: torch.BoolTensor, + attn_mask: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + if self.pre_norm: + n_tgt = self.self_attn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.self_attn(n_tgt, n_tgt, n_tgt, tgt_mask, attn_mask)) + n_tgt = self.mha_attn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.mha_attn(n_tgt, x_src, x_src, src_mask)) + n_tgt = self.ffn_norm(x_tgt) + x_tgt = x_tgt + self.dropout(self.ffn(x_tgt)) + else: + x_tgt = self.self_attn_norm(x_tgt + self.dropout(self.self_attn(x_tgt, x_tgt, x_tgt, tgt_mask, attn_mask))) + x_tgt = self.mha_attn_norm(x_tgt + self.dropout(self.mha_attn(x_tgt, x_src, x_src, src_mask))) + x_tgt = self.ffn_norm(x_tgt + self.dropout(self.ffn(x_tgt))) + return x_tgt + + +class MultiHeadAttention(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_embed: int = 128, + dropout: float = 0.1, + bias: bool = True, + attn: bool = False, + ) -> MultiHeadAttention: + super(MultiHeadAttention, self).__init__() + + self.n_heads = n_heads + self.n_model = n_model + self.n_embed = n_embed + self.scale = n_embed**0.5 + + self.wq = nn.Linear(n_model, n_heads * n_embed, bias=bias) + self.wk = nn.Linear(n_model, n_heads * n_embed, bias=bias) + self.wv = nn.Linear(n_model, n_heads * n_embed, bias=bias) + self.wo = nn.Linear(n_heads * n_embed, n_model, bias=bias) + self.dropout = nn.Dropout(dropout) + + self.bias = bias + self.attn = attn + + self.reset_parameters() + + def reset_parameters(self): + # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py + nn.init.xavier_uniform_(self.wq.weight, 2 ** -0.5) + nn.init.xavier_uniform_(self.wk.weight, 2 ** -0.5) + nn.init.xavier_uniform_(self.wv.weight, 2 ** -0.5) + nn.init.xavier_uniform_(self.wo.weight) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.BoolTensor, + attn_mask: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + batch_size, _ = mask.shape + # [seq_len, batch_size * n_heads, n_embed] + q = self.wq(q).view(-1, batch_size * self.n_heads, self.n_embed) + # [src_len, batch_size * n_heads, n_embed] + k = self.wk(k).view(-1, batch_size * self.n_heads, self.n_embed) + v = self.wv(v).view(-1, batch_size * self.n_heads, self.n_embed) + + mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:]) + # [batch_size * n_heads, seq_len, src_len] + if attn_mask is not None: + mask = mask & attn_mask + # [batch_size * n_heads, seq_len, src_len] + attn = torch.bmm(q.transpose(0, 1) / self.scale, k.movedim((0, 1), (2, 0))) + attn = torch.softmax(attn + torch.where(mask, 0., float('-inf')), -1) + attn = self.dropout(attn) + # [seq_len, batch_size * n_heads, n_embed] + x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1) + # [seq_len, batch_size, n_model] + x = self.wo(x.reshape(-1, batch_size, self.n_heads * self.n_embed)) + + return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x + + +class RelativePositionMultiHeadAttention(nn.Module): + + def __init__( + self, + n_heads: int = 8, + n_model: int = 1024, + n_embed: int = 128, + dropout: float = 0.1, + attn: bool = False + ) -> RelativePositionMultiHeadAttention: + super(RelativePositionMultiHeadAttention, self).__init__() + + self.n_heads = n_heads + self.n_model = n_model + self.n_embed = n_embed + self.scale = n_embed**0.5 + + self.pos_embed = RelativePositionalEmbedding(n_model=n_embed) + self.wq = nn.Parameter(torch.zeros(n_model, n_heads * n_embed)) + self.wk = nn.Parameter(torch.zeros(n_model, n_heads * n_embed)) + self.wv = nn.Parameter(torch.zeros(n_model, n_heads * n_embed)) + self.wo = nn.Parameter(torch.zeros(n_heads * n_embed, n_model)) + self.bu = nn.Parameter(torch.zeros(n_heads, n_embed)) + self.bv = nn.Parameter(torch.zeros(n_heads, n_embed)) + self.dropout = nn.Dropout(dropout) + + self.attn = attn + + self.reset_parameters() + + def reset_parameters(self): + # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py + nn.init.xavier_uniform_(self.wq, 2 ** -0.5) + nn.init.xavier_uniform_(self.wk, 2 ** -0.5) + nn.init.xavier_uniform_(self.wv, 2 ** -0.5) + nn.init.xavier_uniform_(self.wo) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.BoolTensor, + attn_mask: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + batch_size, _ = mask.shape + # [seq_len, batch_size, n_heads, n_embed] + q = F.linear(q, self.wq).view(-1, batch_size, self.n_heads, self.n_embed) + # [src_len, batch_size * n_heads, n_embed] + k = F.linear(k, self.wk).view(-1, batch_size * self.n_heads, self.n_embed) + v = F.linear(v, self.wv).view(-1, batch_size * self.n_heads, self.n_embed) + # [seq_len, src_len, n_embed] + p = self.pos_embed(q[:, 0, 0], k[:, 0]) + # [seq_len, batch_size * n_heads, n_embed] + qu, qv = (q + self.bu).view(-1, *k.shape[1:]), (q + self.bv).view(-1, *k.shape[1:]) + + mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:]) + if attn_mask is not None: + mask = mask & attn_mask + # [batch_size * n_heads, seq_len, src_len] + attn = torch.bmm(qu.transpose(0, 1), k.movedim((0, 1), (2, 0))) + attn = attn + torch.matmul(qv.transpose(0, 1).unsqueeze(2), p.transpose(1, 2)).squeeze(2) + attn = torch.softmax(attn / self.scale + torch.where(mask, 0., float('-inf')), -1) + attn = self.dropout(attn) + # [seq_len, batch_size * n_heads, n_embed] + x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1) + # [seq_len, batch_size, n_model] + x = F.linear(x.reshape(-1, batch_size, self.n_heads * self.n_embed), self.wo) + + return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x + + +class PositionwiseFeedForward(nn.Module): + + def __init__( + self, + n_model: int = 1024, + n_inner: int = 2048, + activation: str = 'relu', + dropout: float = 0.1 + ) -> PositionwiseFeedForward: + super(PositionwiseFeedForward, self).__init__() + + self.w1 = nn.Linear(n_model, n_inner) + self.activation = nn.ReLU() if activation == 'relu' else nn.GELU() + self.dropout = nn.Dropout(dropout) + self.w2 = nn.Linear(n_inner, n_model) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.w1.weight) + nn.init.xavier_uniform_(self.w2.weight) + nn.init.zeros_(self.w1.bias) + nn.init.zeros_(self.w2.bias) + + def forward(self, x): + x = self.w1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.w2(x) + + return x + + +class PositionalEmbedding(nn.Module): + + def __init__( + self, + n_model: int = 1024, + max_len: int = 1024 + ) -> PositionalEmbedding: + super().__init__() + + self.embed = nn.Embedding(max_len, n_model) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + w = self.embed.weight + max_len, n_model = w.shape + w = w.new_tensor(range(max_len)).unsqueeze(-1) + w = w / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model) + w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos() + self.embed.weight.copy_(w) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embed(x.new_tensor(range(x.shape[1])).long()) + + +class RelativePositionalEmbedding(nn.Module): + + def __init__( + self, + n_model: int = 1024, + max_len: int = 1024 + ) -> RelativePositionalEmbedding: + super().__init__() + + self.embed = nn.Embedding(max_len, n_model) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + w = self.embed.weight + max_len, n_model = w.shape + pos = torch.cat((w.new_tensor(range(-max_len//2, 0)), w.new_tensor(range(max_len//2)))) + w = pos.unsqueeze(-1) / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model) + w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos() + self.embed.weight.copy_(w) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + offset = sum(divmod(self.embed.weight.shape[0], 2)) + return self.embed((k.new_tensor(range(k.shape[0])) - q.new_tensor(range(q.shape[0])).unsqueeze(-1)).long() + offset) + + +class SinusoidPositionalEmbedding(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + seq_len, n_model = x[0].shape + pos = x.new_tensor(range(seq_len)).unsqueeze(-1) + pos = pos / 10000 ** (x.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model) + pos[:, 0::2], pos[:, 1::2] = pos[:, 0::2].sin(), pos[:, 1::2].cos() + return pos + + +class SinusoidRelativePositionalEmbedding(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + seq_len, n_model = x[0].shape + pos = x.new_tensor(range(seq_len)) + pos = (pos - pos.unsqueeze(-1)).unsqueeze(-1) + pos = pos / 10000 ** (x.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model) + pos[..., 0::2], pos[..., 1::2] = pos[..., 0::2].sin(), pos[..., 1::2].cos() + return pos diff --git a/tania_scripts/supar/parser.py b/tania_scripts/supar/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c372e52475d87780f414396606f71a7494653b --- /dev/null +++ b/tania_scripts/supar/parser.py @@ -0,0 +1,620 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import contextlib +import os +import shutil +import sys +import tempfile +import pickle +from contextlib import contextmanager +from datetime import datetime, timedelta +from typing import Any, Iterable, Union + +import dill +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.cuda.amp import GradScaler +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler + +import supar +from supar.utils import Config, Dataset +from supar.utils.field import Field +from supar.utils.fn import download, get_rng_state, set_rng_state +from supar.utils.logging import get_logger, init_logger, progress_bar +from supar.utils.metric import Metric +from supar.utils.optim import InverseSquareRootLR, LinearLR +from supar.utils.parallel import DistributedDataParallel as DDP +from supar.utils.parallel import gather, is_dist, is_master, reduce +from supar.utils.transform import Batch + +logger = get_logger(__name__) + + +class Parser(object): + + NAME = None + MODEL = None + + def __init__(self, args, model, transform): + self.args = args + self.model = model + self.transform = transform + + @property + def device(self): + return 'cuda' if torch.cuda.is_available() else 'cpu' + + @property + def sync_grad(self): + return self.step % self.args.update_steps == 0 or self.step % self.n_batches == 0 + + @contextmanager + def sync(self): + context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext') + if is_dist() and not self.sync_grad: + context = self.model.no_sync + with context(): + yield + + @contextmanager + def join(self): + context = getattr(contextlib, 'suppress' if sys.version < '3.7' else 'nullcontext') + if not is_dist(): + with context(): + yield + elif self.model.training: + with self.model.join(): + yield + else: + try: + dist_model = self.model + # https://github.com/pytorch/pytorch/issues/54059 + if hasattr(self.model, 'module'): + self.model = self.model.module + yield + finally: + self.model = dist_model + + def train( + self, + train: Union[str, Iterable], + dev: Union[str, Iterable], + test: Union[str, Iterable], + epochs: int, + patience: int, + batch_size: int = 5000, + update_steps: int = 1, + buckets: int = 32, + workers: int = 0, + clip: float = 5.0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ) -> None: + r""" + Args: + train/dev/test (Union[str, Iterable]): + Filenames of the train/dev/test datasets. + epochs (int): + The number of training iterations. + patience (int): + The number of consecutive iterations after which the training process would be early stopped if no improvement. + batch_size (int): + The number of tokens in each batch. Default: 5000. + update_steps (int): + Gradient accumulation steps. Default: 1. + buckets (int): + The number of buckets that sentences are assigned to. Default: 32. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + clip (float): + Clips gradient of an iterable of parameters at specified value. Default: 5.0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + self.transform.train() + batch_size = batch_size // update_steps + eval_batch_size = args.get('eval_batch_size', batch_size) + if is_dist(): + batch_size = batch_size // dist.get_world_size() + eval_batch_size = eval_batch_size // dist.get_world_size() + logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + args.even = args.get('even', is_dist()) + + train = Dataset(self.transform, args.train, **args).build(batch_size=batch_size, + n_buckets=buckets, + shuffle=True, + distributed=is_dist(), + even=args.even, + n_workers=workers) + dev = Dataset(self.transform, args.dev, **args).build(batch_size=eval_batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"{'train:':6} {train}") + if not args.test: + logger.info(f"{'dev:':6} {dev}\n") + else: + test = Dataset(self.transform, args.test, **args).build(batch_size=eval_batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"{'dev:':6} {dev}") + logger.info(f"{'test:':6} {test}\n") + loader, sampler = train.loader, train.loader.batch_sampler + args.steps = len(loader) * epochs // args.update_steps + args.save(f"{args.path}.yaml") + + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + self.scaler = GradScaler(enabled=args.amp) + + if dist.is_initialized(): + self.model = DDP(module=self.model, + device_ids=[args.local_rank], + find_unused_parameters=args.get('find_unused_parameters', True), + static_graph=args.get('static_graph', False)) + if args.amp: + from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook + self.model.register_comm_hook(dist.group.WORLD, fp16_compress_hook) + if args.wandb and is_master(): + import wandb + # start a new wandb run to track this script + wandb.init(config=args.primitive_config, + project=args.get('project', self.NAME), + name=args.get('name', args.path), + resume=self.args.checkpoint) + self.step, self.epoch, self.best_e, self.patience = 1, 1, 1, patience + # uneven batches are excluded + self.n_batches = min(gather(len(loader))) if is_dist() else len(loader) + self.best_metric, self.elapsed = Metric(), timedelta() + if args.checkpoint: + try: + self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict')) + self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict')) + self.scaler.load_state_dict(self.checkpoint_state_dict.pop('scaler_state_dict')) + set_rng_state(self.checkpoint_state_dict.pop('rng_state')) + for k, v in self.checkpoint_state_dict.items(): + setattr(self, k, v) + sampler.set_epoch(self.epoch) + except AttributeError: + logger.warning("No checkpoint found. Try re-launching the training procedure instead") + + for epoch in range(self.epoch, args.epochs + 1): + start = datetime.now() + bar, metric = progress_bar(loader), Metric() + + logger.info(f"Epoch {epoch} / {args.epochs}:") + self.model.train() + with self.join(): + # we should reset `step` as the number of batches in different processes is not necessarily equal + self.step = 1 + for batch in bar: + with self.sync(): + with torch.autocast(self.device, enabled=args.amp): + loss = self.train_step(batch) + self.backward(loss) + if self.sync_grad: + self.clip_grad_norm_(self.model.parameters(), args.clip) + self.scaler.step(self.optimizer) + self.scaler.update() + self.scheduler.step() + self.optimizer.zero_grad(True) + + bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}") + # log metrics to wandb + if args.wandb and is_master(): + wandb.log({'lr': self.scheduler.get_last_lr()[0], 'loss': loss}) + self.step += 1 + logger.info(f"{bar.postfix}") + self.model.eval() + with self.join(), torch.autocast(self.device, enabled=args.amp): + metric = self.reduce(sum([self.eval_step(i) for i in progress_bar(dev.loader)], Metric())) + logger.info(f"{'dev:':5} {metric}") + if args.wandb and is_master(): + wandb.log({'dev': metric.values, 'epochs': epoch}) + if args.test: + test_metric = sum([self.eval_step(i) for i in progress_bar(test.loader)], Metric()) + logger.info(f"{'test:':5} {self.reduce(test_metric)}") + if args.wandb and is_master(): + wandb.log({'test': test_metric.values, 'epochs': epoch}) + + t = datetime.now() - start + self.epoch += 1 + self.patience -= 1 + self.elapsed += t + + if metric > self.best_metric: + self.best_e, self.patience, self.best_metric = epoch, patience, metric + if is_master(): + self.save_checkpoint(args.path) + logger.info(f"{t}s elapsed (saved)\n") + else: + logger.info(f"{t}s elapsed\n") + if self.patience < 1: + break + if is_dist(): + dist.barrier() + + best = self.load(**args) + # only allow the master device to save models + if is_master(): + best.save(args.path) + + logger.info(f"Epoch {self.best_e} saved") + logger.info(f"{'dev:':5} {self.best_metric}") + if args.test: + best.model.eval() + with best.join(): + test_metric = sum([best.eval_step(i) for i in progress_bar(test.loader)], Metric()) + logger.info(f"{'test:':5} {best.reduce(test_metric)}") + logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch") + if args.wandb and is_master(): + wandb.finish() + + self.evaluate(data=args.test, batch_size=batch_size) + self.predict(args.test, batch_size=batch_size, buckets=buckets, workers=workers) + + with open(f'{self.args.folder}/status', 'w') as file: + file.write('finished') + + + + def evaluate( + self, + data: Union[str, Iterable], + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + amp: bool = False, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + r""" + Args: + data (Union[str, Iterable]): + The data for evaluation. Both a filename and a list of instances are allowed. + batch_size (int): + The number of tokens in each batch. Default: 5000. + buckets (int): + The number of buckets that sentences are assigned to. Default: 8. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + + Returns: + The evaluation results. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + self.transform.train() + logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + if is_dist(): + batch_size = batch_size // dist.get_world_size() + data = Dataset(self.transform, **args) + data.build(batch_size=batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + logger.info(f"\n{data}") + + logger.info("Evaluating the data") + start = datetime.now() + self.model.eval() + with self.join(): + bar, metric = progress_bar(data.loader), Metric() + for batch in bar: + metric += self.eval_step(batch) + bar.set_postfix_str(metric) + metric = self.reduce(metric) + elapsed = datetime.now() - start + logger.info(f"{metric}") + logger.info(f"{elapsed}s elapsed, " + f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, " + f"{len(data)/elapsed.total_seconds():.2f} Sents/s") + os.makedirs(os.path.dirname(self.args.folder + '/metrics.pickle'), exist_ok=True) + with open(f'{self.args.folder}/metrics.pickle', 'wb') as file: + pickle.dump(obj=metric, file=file) + + return metric + + def predict( + self, + data: Union[str, Iterable], + pred: str = None, + lang: str = None, + prob: bool = False, + batch_size: int = 5000, + buckets: int = 8, + workers: int = 0, + cache: bool = False, + verbose: bool = True, + **kwargs + ): + r""" + Args: + data (Union[str, Iterable]): + The data for prediction. + - a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts. + - a list of instances. + pred (str): + If specified, the predicted results will be saved to the file. Default: ``None``. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + prob (bool): + If ``True``, outputs the probabilities. Default: ``False``. + batch_size (int): + The number of tokens in each batch. Default: 5000. + buckets (int): + The number of buckets that sentences are assigned to. Default: 8. + workers (int): + The number of subprocesses used for data loading. 0 means only the main process. Default: 0. + amp (bool): + Specifies whether to use automatic mixed precision. Default: ``False``. + cache (bool): + If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``. + verbose (bool): + If ``True``, increases the output verbosity. Default: ``True``. + + Returns: + A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``. + """ + + args = self.args.update(locals()) + init_logger(logger, verbose=args.verbose) + + if self.args.use_vq: + self.model.passes_remaining = 0 + self.model.vq.observe_steps_remaining = 0 + + self.transform.eval() + if args.prob: + self.transform.append(Field('probs')) + + #logger.info("Loading the data") + if args.cache: + args.bin = os.path.join(os.path.dirname(args.path), 'bin') + if is_dist(): + batch_size = batch_size // dist.get_world_size() + data = Dataset(self.transform, **args) + data.build(batch_size=batch_size, + n_buckets=buckets, + shuffle=False, + distributed=is_dist(), + even=False, + n_workers=workers) + + #logger.info(f"\n{data}") + + #logger.info("Making predictions on the data") + start = datetime.now() + self.model.eval() + #with tempfile.TemporaryDirectory() as t: + # we have clustered the sentences by length here to speed up prediction, + # so the order of the yielded sentences can't be guaranteed + for batch in progress_bar(data.loader): + #batch, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict = self.pred_step(batch) + *predicted_values, = self.pred_step(batch) + #print('429 supar/parser.py ', batch.sentences) + #print(head_preds, deprel_preds, stack_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text) + #logger.info(f"Saving predicted results to {pred}") + # with open(pred, 'w') as f: + + + elapsed = datetime.now() - start + + #if is_dist(): + # dist.barrier() + #tdirs = gather(t) if is_dist() else (t,) + #if pred is not None and is_master(): + #logger.info(f"Saving predicted results to {pred}") + """with open(pred, 'w') as f: + # merge all predictions into one single file + if is_dist() or args.cache: + sentences = (os.path.join(i, s) for i in tdirs for s in os.listdir(i)) + for i in progress_bar(sorted(sentences, key=lambda x: int(os.path.basename(x)))): + with open(i) as s: + shutil.copyfileobj(s, f) + else: + for s in progress_bar(data): + f.write(str(s) + '\n')""" + # exit util all files have been merged + if is_dist(): + dist.barrier() + #logger.info(f"{elapsed}s elapsed, " + # f"{sum(data.sizes)/elapsed.total_seconds():.2f} Tokens/s, " + # f"{len(data)/elapsed.total_seconds():.2f} Sents/s") + + if not cache: + #return data, head_preds, deprel_preds, stack_list, buffer_list, actions_list, act_dict_list, deprel_preds_decoded, pos_preds_decoded, sent_text, act_dict + return *predicted_values, + + def backward(self, loss: torch.Tensor, **kwargs): + loss /= self.args.update_steps + if hasattr(self, 'scaler'): + self.scaler.scale(loss).backward(**kwargs) + else: + loss.backward(**kwargs) + + def clip_grad_norm_( + self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + max_norm: float, + norm_type: float = 2 + ) -> torch.Tensor: + self.scaler.unscale_(self.optimizer) + return nn.utils.clip_grad_norm_(params, max_norm, norm_type) + + def clip_grad_value_( + self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + clip_value: float + ) -> None: + self.scaler.unscale_(self.optimizer) + return nn.utils.clip_grad_value_(params, clip_value) + + def reduce(self, obj: Any) -> Any: + if not is_dist(): + return obj + return reduce(obj) + + def train_step(self, batch: Batch) -> torch.Tensor: + ... + + @torch.no_grad() + def eval_step(self, batch: Batch) -> Metric: + ... + + @torch.no_grad() + def pred_step(self, batch: Batch) -> Batch: + ... + + def init_optimizer(self) -> Optimizer: + if self.args.encoder in ('lstm', 'transformer'): + optimizer = Adam(params=self.model.parameters(), + lr=self.args.lr, + betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)), + eps=self.args.get('eps', 1e-8), + weight_decay=self.args.get('weight_decay', 0)) + else: + # we found that Huggingface's AdamW is more robust and empirically better than the native implementation + from transformers import AdamW + optimizer = AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)} + for n, p in self.model.named_parameters()], + lr=self.args.lr, + betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)), + eps=self.args.get('eps', 1e-8), + weight_decay=self.args.get('weight_decay', 0)) + return optimizer + + def init_scheduler(self) -> _LRScheduler: + if self.args.encoder == 'lstm': + scheduler = ExponentialLR(optimizer=self.optimizer, + gamma=self.args.decay**(1/self.args.decay_steps)) + elif self.args.encoder == 'transformer': + scheduler = InverseSquareRootLR(optimizer=self.optimizer, + warmup_steps=self.args.warmup_steps) + else: + scheduler = LinearLR(optimizer=self.optimizer, + warmup_steps=self.args.get('warmup_steps', int(self.args.steps*self.args.get('warmup', 0))), + steps=self.args.steps) + return scheduler + + @classmethod + def build(cls, path, **kwargs): + ... + + @classmethod + def load( + cls, + path: str, + reload: bool = True, + src: str = 'github', + checkpoint: bool = False, + **kwargs + ) -> Parser: + r""" + Loads a parser with data fields and pretrained model parameters. + + Args: + path (str): + - a string with the shortcut name of a pretrained model defined in ``supar.MODEL`` + to load from cache or download, e.g., ``'biaffine-dep-en'``. + - a local path to a pretrained model, e.g., ``./<path>/model``. + reload (bool): + Whether to discard the existing cache and force a fresh download. Default: ``False``. + src (str): + Specifies where to download the model. + ``'github'``: github release page. + ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8). + Default: ``'github'``. + checkpoint (bool): + If ``True``, loads all checkpoint states to restore the training process. Default: ``False``. + + Examples: + >>> from supar import Parser + >>> parser = Parser.load('biaffine-dep-en') + >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char') + """ + + args = Config(**locals()) + if not os.path.exists(path): + path = download(supar.MODEL[src].get(path, path), reload=reload) + state = torch.load(path, map_location='cpu', weights_only=False) + #torch.load(path, map_location='cpu') + cls = supar.PARSER[state['name']] if cls.NAME is None else cls + args = state['args'].update(args) + #print('ARGS', args) + model = cls.MODEL(**args) + model.load_pretrained(state['pretrained']) + model.load_state_dict(state['state_dict'], True) + transform = state['transform'] + parser = cls(args, model, transform) + parser.checkpoint_state_dict = state.get('checkpoint_state_dict', None) if checkpoint else None + parser.model.to(parser.device) + return parser + + def save(self, path: str) -> None: + model = self.model + if hasattr(model, 'module'): + model = self.model.module + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + pretrained = state_dict.pop('pretrained.weight', None) + state = {'name': self.NAME, + 'args': model.args, + 'state_dict': state_dict, + 'pretrained': pretrained, + 'transform': self.transform} + torch.save(state, path, pickle_module=dill) + + def save_checkpoint(self, path: str) -> None: + model = self.model + if hasattr(model, 'module'): + model = self.model.module + checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']} + checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'scaler_state_dict': self.scaler.state_dict(), + 'rng_state': get_rng_state()}) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + pretrained = state_dict.pop('pretrained.weight', None) + state = {'name': self.NAME, + 'args': model.args, + 'state_dict': state_dict, + 'pretrained': pretrained, + 'checkpoint_state_dict': checkpoint_state_dict, + 'transform': self.transform} + torch.save(state, path, pickle_module=dill) diff --git a/tania_scripts/supar/structs/__init__.py b/tania_scripts/supar/structs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9afbd79927a0bb15b198b08ae53a18af03e42b7f --- /dev/null +++ b/tania_scripts/supar/structs/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +from .chain import LinearChainCRF, SemiMarkovCRF +from .dist import StructuredDistribution +from .tree import (BiLexicalizedConstituencyCRF, ConstituencyCRF, + Dependency2oCRF, DependencyCRF, MatrixTree) +from .vi import (ConstituencyLBP, ConstituencyMFVI, DependencyLBP, + DependencyMFVI, SemanticDependencyLBP, SemanticDependencyMFVI) + +__all__ = ['StructuredDistribution', + 'LinearChainCRF', + 'SemiMarkovCRF', + 'MatrixTree', + 'DependencyCRF', + 'Dependency2oCRF', + 'ConstituencyCRF', + 'BiLexicalizedConstituencyCRF', + 'DependencyMFVI', + 'DependencyLBP', + 'ConstituencyMFVI', + 'ConstituencyLBP', + 'SemanticDependencyMFVI', + 'SemanticDependencyLBP', ] diff --git a/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef548e93f5fe211cae6a010de99dbcb04c3710ad Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4517337ef5ef26d1d5c1d3e7001d4ea719d5120 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42d31e38ad624ca095da2a77618bca27d421e88 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/chain.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f42a094bb248a728f5748ec2648f7828e587fdd Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/chain.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd9588165983a30bc91fedfe71a124d9e274b7a Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/dist.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde8b84641e2db640d4bc8b6077df3a2e1aae03a Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/dist.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128838b69315f5cfbbab3bcd5933e1af0b4e9816 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/fn.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52542e3188bfcde5e56f78ad516e4e75e2da4552 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/fn.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eebbfcd87f3a9e230655a6f65e90217253551767 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/semiring.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9cfb3a935ad9c1fee1faac17b0a0e4317f6aa3d Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/semiring.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cfa02dffd8263f04ff8556087083f71ef8bccee Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/tree.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42efb7cbe328a1db42ec469766abe4c93796ba5 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/tree.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc b/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc6ac96670b2c3dda14334e09a50e31c53a3ff5 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/vi.cpython-310.pyc differ diff --git a/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc b/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b5a62221a93d1af40f302d54c6da4bfff1818d4 Binary files /dev/null and b/tania_scripts/supar/structs/__pycache__/vi.cpython-311.pyc differ diff --git a/tania_scripts/supar/structs/chain.py b/tania_scripts/supar/structs/chain.py new file mode 100644 index 0000000000000000000000000000000000000000..828caab76fac936df414ee46118050525933405f --- /dev/null +++ b/tania_scripts/supar/structs/chain.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List, Optional + +import torch +from supar.structs.dist import StructuredDistribution +from supar.structs.semiring import LogSemiring, Semiring +from torch.distributions.utils import lazy_property + + +class LinearChainCRF(StructuredDistribution): + r""" + Linear-chain CRFs :cite:`lafferty-etal-2001-crf`. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``. + Log potentials. + trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``. + Transition scores. + ``trans[-1, :-1]``/``trans[:-1, -1]`` represent oracle for start/end positions respectively. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking. Default: ``None``. + + Examples: + >>> from supar import LinearChainCRF + >>> batch_size, seq_len, n_tags = 2, 5, 4 + >>> lens = torch.tensor([3, 4]) + >>> value = torch.randint(n_tags, (batch_size, seq_len)) + >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), + torch.randn(n_tags+1, n_tags+1), + lens) + >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags), + torch.randn(n_tags+1, n_tags+1), + lens) + >>> s1.max + tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>) + >>> s1.argmax + tensor([[2, 0, 3, 0, 0], + [3, 3, 3, 2, 0]]) + >>> s1.log_partition + tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>) + >>> s1.log_prob(value) + tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([3.4150, 3.6549], grad_fn=<SelectBackward>) + >>> s1.kl(s2) + tensor([4.0333, 4.3807], grad_fn=<SelectBackward>) + """ + + def __init__( + self, + scores: torch.Tensor, + trans: Optional[torch.Tensor] = None, + lens: Optional[torch.LongTensor] = None + ) -> LinearChainCRF: + super().__init__(scores, lens=lens) + + batch_size, seq_len, self.n_tags = scores.shape[:3] + self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens + self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len))) + + self.trans = self.scores.new_full((self.n_tags+1, self.n_tags+1), LogSemiring.one) if trans is None else trans + + def __repr__(self): + return f"{self.__class__.__name__}(n_tags={self.n_tags})" + + def __add__(self, other): + return LinearChainCRF(torch.stack((self.scores, other.scores), -1), + torch.stack((self.trans, other.trans), -1), + self.lens) + + @lazy_property + def argmax(self): + return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2]) + + def topk(self, k: int) -> torch.LongTensor: + preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1) + return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds) + + def score(self, value: torch.LongTensor) -> torch.Tensor: + scores, mask, value = self.scores.transpose(0, 1), self.mask.t(), value.t() + prev, succ = torch.cat((torch.full_like(value[:1], -1), value[:-1]), 0), value + # [seq_len, batch_size] + alpha = scores.gather(-1, value.unsqueeze(-1)).squeeze(-1) + # [batch_size] + alpha = LogSemiring.prod(LogSemiring.one_mask(LogSemiring.mul(alpha, self.trans[prev, succ]), ~mask), 0) + alpha = alpha + self.trans[value.gather(0, self.lens.unsqueeze(0) - 1).squeeze(0), torch.full_like(value[0], -1)] + return alpha + + def forward(self, semiring: Semiring) -> torch.Tensor: + # [seq_len, batch_size, n_tags, ...] + scores = semiring.convert(self.scores.transpose(0, 1)) + trans = semiring.convert(self.trans) + mask = self.mask.t() + + # [batch_size, n_tags] + alpha = semiring.mul(trans[-1, :-1], scores[0]) + for i in range(1, len(mask)): + alpha[mask[i]] = semiring.mul(semiring.dot(alpha.unsqueeze(2), trans[:-1, :-1], 1), scores[i])[mask[i]] + alpha = semiring.dot(alpha, trans[:-1, -1], 1) + return semiring.unconvert(alpha) + + +class SemiMarkovCRF(StructuredDistribution): + r""" + Semi-markov CRFs :cite:`sarawagi-cohen-2004-semicrf`. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_tags]``. + Log potentials. + trans (~torch.Tensor): ``[n_tags, n_tags]``. + Transition scores. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking. Default: ``None``. + + Examples: + >>> from supar import SemiMarkovCRF + >>> batch_size, seq_len, n_tags = 2, 5, 4 + >>> lens = torch.tensor([3, 4]) + >>> value = torch.tensor([[[ 0, -1, -1, -1, -1], + [-1, -1, 2, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]], + [[-1, 1, -1, -1, -1], + [-1, -1, 3, -1, -1], + [-1, -1, -1, 0, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]]]) + >>> s1 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags), + torch.randn(n_tags, n_tags), + lens) + >>> s2 = SemiMarkovCRF(torch.randn(batch_size, seq_len, seq_len, n_tags), + torch.randn(n_tags, n_tags), + lens) + >>> s1.max + tensor([4.1971, 5.5746], grad_fn=<MaxBackward0>) + >>> s1.argmax + [[[0, 0, 1], [1, 1, 0], [2, 2, 1]], [[0, 0, 1], [1, 1, 3], [2, 2, 0], [3, 3, 1]]] + >>> s1.log_partition + tensor([6.3641, 8.4384], grad_fn=<LogsumexpBackward0>) + >>> s1.log_prob(value) + tensor([-5.7982, -7.4534], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([3.7520, 5.1609], grad_fn=<SelectBackward0>) + >>> s1.kl(s2) + tensor([3.5348, 2.2826], grad_fn=<SelectBackward0>) + """ + + def __init__( + self, + scores: torch.Tensor, + trans: Optional[torch.Tensor] = None, + lens: Optional[torch.LongTensor] = None + ) -> SemiMarkovCRF: + super().__init__(scores, lens=lens) + + batch_size, seq_len, _, self.n_tags = scores.shape[:4] + self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens + self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.unsqueeze(1) & self.mask.unsqueeze(2) + + self.trans = self.scores.new_full((self.n_tags, self.n_tags), LogSemiring.one) if trans is None else trans + + def __repr__(self): + return f"{self.__class__.__name__}(n_tags={self.n_tags})" + + def __add__(self, other): + return SemiMarkovCRF(torch.stack((self.scores, other.scores), -1), + torch.stack((self.trans, other.trans), -1), + self.lens) + + @lazy_property + def argmax(self) -> List: + return [torch.nonzero(i).tolist() for i in self.backward(self.max.sum())] + + def topk(self, k: int) -> List: + return list(zip(*[[torch.nonzero(j).tolist() for j in self.backward(i)] for i in self.kmax(k).sum(0)])) + + def score(self, value: torch.LongTensor) -> torch.Tensor: + mask = self.mask & value.ge(0) + lens = mask.sum((1, 2)) + indices = torch.where(mask) + batch_size, seq_len = lens.shape[0], lens.max() + span_mask = lens.unsqueeze(-1).gt(lens.new_tensor(range(seq_len))) + scores = self.scores.new_full((batch_size, seq_len), LogSemiring.one) + scores = scores.masked_scatter_(span_mask, self.scores[(*indices, value[indices])]) + scores = LogSemiring.prod(LogSemiring.one_mask(scores, ~span_mask), -1) + value = value.new_zeros(batch_size, seq_len).masked_scatter_(span_mask, value[indices]) + trans = LogSemiring.prod(LogSemiring.one_mask(self.trans[value[:, :-1], value[:, 1:]], ~span_mask[:, 1:]), -1) + return LogSemiring.mul(scores, trans) + + def forward(self, semiring: Semiring) -> torch.Tensor: + # [seq_len, seq_len, batch_size, n_tags, ...] + scores = semiring.convert(self.scores.movedim((1, 2), (0, 1))) + trans = semiring.convert(self.trans) + # [seq_len, batch_size, n_tags, ...] + alpha = semiring.zeros_like(scores[0]) + + alpha[0] = scores[0, 0] + # [batch_size, n_tags] + for t in range(1, len(scores)): + # [batch_size, n_tags, ...] + s = semiring.dot(semiring.dot(alpha[:t].unsqueeze(3), trans, 2), scores[1:t+1, t], 0) + alpha[t] = semiring.sum(torch.stack((s, scores[0, t])), 0) + return semiring.unconvert(semiring.sum(alpha[self.lens - 1, range(len(self.lens))], 1)) diff --git a/tania_scripts/supar/structs/dist.py b/tania_scripts/supar/structs/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab594d4e4c72374122e66672529394f656d788e --- /dev/null +++ b/tania_scripts/supar/structs/dist.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Iterable, Union + +import torch +import torch.autograd as autograd +from supar.structs.semiring import (CrossEntropySemiring, EntropySemiring, + KLDivergenceSemiring, KMaxSemiring, + LogSemiring, MaxSemiring, SampledSemiring, + Semiring) +from torch.distributions.distribution import Distribution +from torch.distributions.utils import lazy_property + + +class StructuredDistribution(Distribution): + r""" + Base class for structured distribution :math:`p(y)` :cite:`eisner-2016-inside,goodman-1999-semiring,li-eisner-2009-first`. + + Args: + scores (torch.Tensor): + Log potentials, also for high-order cases. + + """ + + def __init__(self, scores: torch.Tensor, **kwargs) -> StructuredDistribution: + self.scores = scores.requires_grad_() if isinstance(scores, torch.Tensor) else [s.requires_grad_() for s in scores] + self.kwargs = kwargs + + def __repr__(self): + return f"{self.__class__.__name__}()" + + def __add__(self, other: 'StructuredDistribution') -> StructuredDistribution: + return self.__class__(torch.stack((self.scores, other.scores), -1), lens=self.lens) + + @lazy_property + def log_partition(self): + r""" + Computes the log partition function of the distribution :math:`p(y)`. + """ + + return self.forward(LogSemiring) + + @lazy_property + def marginals(self): + r""" + Computes marginal probabilities of the distribution :math:`p(y)`. + """ + + return self.backward(self.log_partition.sum()) + + @lazy_property + def max(self): + r""" + Computes the max score of the distribution :math:`p(y)`. + """ + + return self.forward(MaxSemiring) + + @lazy_property + def argmax(self): + r""" + Computes :math:`\arg\max_y p(y)` of the distribution :math:`p(y)`. + """ + + return self.backward(self.max.sum()) + + @lazy_property + def mode(self): + return self.argmax + + def kmax(self, k: int) -> torch.Tensor: + r""" + Computes the k-max of the distribution :math:`p(y)`. + """ + + return self.forward(KMaxSemiring(k)) + + def topk(self, k: int) -> Union[torch.Tensor, Iterable]: + r""" + Computes the k-argmax of the distribution :math:`p(y)`. + """ + raise NotImplementedError + + def sample(self): + r""" + Obtains a structured sample from the distribution :math:`y \sim p(y)`. + TODO: multi-sampling. + """ + + return self.backward(self.forward(SampledSemiring).sum()).detach() + + @lazy_property + def entropy(self): + r""" + Computes entropy :math:`H[p]` of the distribution :math:`p(y)`. + """ + + return self.forward(EntropySemiring) + + def cross_entropy(self, other: 'StructuredDistribution') -> torch.Tensor: + r""" + Computes cross-entropy :math:`H[p,q]` of self and another distribution. + + Args: + other (~supar.structs.dist.StructuredDistribution): Comparison distribution. + """ + + return (self + other).forward(CrossEntropySemiring) + + def kl(self, other: 'StructuredDistribution') -> torch.Tensor: + r""" + Computes KL-divergence :math:`KL[p \parallel q]=H[p,q]-H[p]` of self and another distribution. + + Args: + other (~supar.structs.dist.StructuredDistribution): Comparison distribution. + """ + + return (self + other).forward(KLDivergenceSemiring) + + def log_prob(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor: + """ + Computes log probability over values :math:`p(y)`. + """ + + return self.score(value, *args, **kwargs) - self.log_partition + + def score(self, value: torch.LongTensor, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError + + @torch.enable_grad() + def forward(self, semiring: Semiring) -> torch.Tensor: + raise NotImplementedError + + def backward(self, log_partition: torch.Tensor) -> Union[torch.Tensor, Iterable[torch.Tensor]]: + grads = autograd.grad(log_partition, self.scores, create_graph=True) + return grads[0] if isinstance(self.scores, torch.Tensor) else grads diff --git a/tania_scripts/supar/structs/fn.py b/tania_scripts/supar/structs/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0af2ad71380169b1b30d27e545bf332a281e84 --- /dev/null +++ b/tania_scripts/supar/structs/fn.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- + +import operator +from typing import Iterable, Tuple, Union + +import torch +from supar.utils.common import INF, MIN +from supar.utils.fn import pad +from torch.autograd import Function + + +def tarjan(sequence: Iterable[int]) -> Iterable[int]: + r""" + Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph. + + Args: + sequence (list): + List of head indices. + + Yields: + A list of indices making up a SCC. All self-loops are ignored. + + Examples: + >>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle + [2, 5, 1] + """ + + sequence = [-1] + sequence + # record the search order, i.e., the timestep + dfn = [-1] * len(sequence) + # record the the smallest timestep in a SCC + low = [-1] * len(sequence) + # push the visited into the stack + stack, onstack = [], [False] * len(sequence) + + def connect(i, timestep): + dfn[i] = low[i] = timestep[0] + timestep[0] += 1 + stack.append(i) + onstack[i] = True + + for j, head in enumerate(sequence): + if head != i: + continue + if dfn[j] == -1: + yield from connect(j, timestep) + low[i] = min(low[i], low[j]) + elif onstack[j]: + low[i] = min(low[i], dfn[j]) + + # a SCC is completed + if low[i] == dfn[i]: + cycle = [stack.pop()] + while cycle[-1] != i: + onstack[cycle[-1]] = False + cycle.append(stack.pop()) + onstack[i] = False + # ignore the self-loop + if len(cycle) > 1: + yield cycle + + timestep = [0] + for i in range(len(sequence)): + if dfn[i] == -1: + yield from connect(i, timestep) + + +def chuliu_edmonds(s: torch.Tensor) -> torch.Tensor: + r""" + ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`. + + Some code is borrowed from `tdozat's implementation`_. + Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`. + + Notes: + The algorithm does not guarantee to parse a single-root tree. + + Args: + s (~torch.Tensor): ``[seq_len, seq_len]``. + Scores of all dependent-head pairs. + + Returns: + ~torch.Tensor: + A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree. + + .. _tdozat's implementation: + https://github.com/tdozat/Parser-v3 + """ + + s[0, 1:] = MIN + # prevent self-loops + s.diagonal()[1:].fill_(MIN) + # select heads with highest scores + tree = s.argmax(-1) + # return the cycle finded by tarjan algorithm lazily + cycle = next(tarjan(tree.tolist()[1:]), None) + # if the tree has no cycles, then it is a MST + if not cycle: + return tree + # indices of cycle in the original tree + cycle = torch.tensor(cycle) + # indices of noncycle in the original tree + noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0) + noncycle = torch.where(noncycle.gt(0))[0] + + def contract(s): + # heads of cycle in original tree + cycle_heads = tree[cycle] + # scores of cycle in original tree + s_cycle = s[cycle, cycle_heads] + + # calculate the scores of cycle's potential dependents + # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle + s_dep = s[noncycle][:, cycle] + # find the best cycle head for each noncycle dependent + deps = s_dep.argmax(1) + # calculate the scores of cycle's potential heads + # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle + # a(v) is the predecessor of v in cycle + # s(cycle) = sum(s(a(v)->v)) + s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum() + # find the best noncycle head for each cycle dependent + heads = s_head.argmax(0) + + contracted = torch.cat((noncycle, torch.tensor([-1]))) + # calculate the scores of contracted graph + s = s[contracted][:, contracted] + # set the contracted graph scores of cycle's potential dependents + s[:-1, -1] = s_dep[range(len(deps)), deps] + # set the contracted graph scores of cycle's potential heads + s[-1, :-1] = s_head[heads, range(len(heads))] + + return s, heads, deps + + # keep track of the endpoints of the edges into and out of cycle for reconstruction later + s, heads, deps = contract(s) + + # y is the contracted tree + y = chuliu_edmonds(s) + # exclude head of cycle from y + y, cycle_head = y[:-1], y[-1] + + # fix the subtree with no heads coming from the cycle + # len(y) denotes heads coming from the cycle + subtree = y < len(y) + # add the nodes to the new tree + tree[noncycle[subtree]] = noncycle[y[subtree]] + # fix the subtree with heads coming from the cycle + subtree = ~subtree + # add the nodes to the tree + tree[noncycle[subtree]] = cycle[deps[subtree]] + # fix the root of the cycle + cycle_root = heads[cycle_head] + # break the cycle and add the root of the cycle to the tree + tree[cycle[cycle_root]] = noncycle[cycle_head] + + return tree + + +def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) -> torch.Tensor: + r""" + MST algorithm for decoding non-projective trees. + This is a wrapper for ChuLiu/Edmonds algorithm. + + The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots, + If ``multiroot=True`` and there indeed exist multi-roots, the algorithm seeks to find + best single-root trees by iterating all possible single-root trees parsed by ChuLiu/Edmonds. + Otherwise the resulting trees are directly taken as the final outputs. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all dependent-head pairs. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid parsing over padding tokens. + The first column serving as pseudo words for roots should be ``False``. + multiroot (bool): + Ensures to parse a single-root tree If ``False``. + + Returns: + ~torch.Tensor: + A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees. + + Examples: + >>> scores = torch.tensor([[[-11.9436, -13.1464, -6.4789, -13.8917], + [-60.6957, -60.2866, -48.6457, -63.8125], + [-38.1747, -49.9296, -45.2733, -49.5571], + [-19.7504, -23.9066, -9.9139, -16.2088]]]) + >>> scores[:, 0, 1:] = MIN + >>> scores.diagonal(0, 1, 2)[1:].fill_(MIN) + >>> mask = torch.tensor([[False, True, True, True]]) + >>> mst(scores, mask) + tensor([[0, 2, 0, 2]]) + """ + + _, seq_len, _ = scores.shape + scores = scores.cpu().unbind() + + preds = [] + for i, length in enumerate(mask.sum(1).tolist()): + s = scores[i][:length+1, :length+1] + tree = chuliu_edmonds(s) + roots = torch.where(tree[1:].eq(0))[0] + 1 + if not multiroot and len(roots) > 1: + s_root = s[:, 0] + s_best = MIN + s = s.index_fill(1, torch.tensor(0), MIN) + for root in roots: + s[:, 0] = MIN + s[root, 0] = s_root[root] + t = chuliu_edmonds(s) + s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum() + if s_tree > s_best: + s_best, tree = s_tree, t + preds.append(tree) + + return pad(preds, total_length=seq_len).to(mask.device) + + +def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int: + """ + Calculates the Levenshtein edit-distance between two sequences. + The edit distance is the number of characters that need to be + substituted, inserted, or deleted, to transform `x` into `y`. + + For example, transforming "rain" to "shine" requires three steps, + consisting of two substitutions and one insertion: + "rain" -> "sain" -> "shin" -> "shine". + These operations could have been done in other orders, but at least three steps are needed. + + Allows specifying the cost of substitution edits (e.g., "a" -> "b"), + because sometimes it makes sense to assign greater penalties to substitutions. + + The code is revised from `nltk`_ and `wiki`_'s implementations. + + Args: + x/y (Iterable): + The sequences to be analysed. + align (bool): + Whether to return the alignments based on the minimum Levenshtein edit-distance. Default: ``False``. + + Examples: + >>> from supar.structs.utils.fn import levenshtein + >>> levenshtein('intention', 'execution', align=True) + (5, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)]) + + .. _nltk: + https://github.com/nltk/nltk/blob/develop/nltk/metrics/distance.py + .. _wiki: + https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance + """ + + # set up a 2-D array + len1, len2 = len(x), len(y) + lev = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)] + + # iterate over the array + # i and j start from 1 and not 0 to stay close to the wikipedia pseudo-code + # see https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance + for i in range(1, len1 + 1): + for j in range(1, len2 + 1): + # substitution + s = lev[i - 1][j - 1] + (x[i - 1] != y[j - 1]) + # deletion + a = lev[i - 1][j] + 1 + # insertion + b = lev[i][j - 1] + 1 + + lev[i][j] = min(s, a, b) + distance = lev[-1][-1] + if align: + i, j = len1, len2 + alignments = [(i, j)] + while (i, j) != (0, 0): + directions = [ + (i - 1, j - 1), # substitution + (i - 1, j), # deletion + (i, j - 1), # insertion + ] + direction_costs = ((lev[i][j] if (i >= 0 and j >= 0) else INF, (i, j)) for i, j in directions) + _, (i, j) = min(direction_costs, key=operator.itemgetter(0)) + alignments.append((i, j)) + alignments = list(reversed(alignments)) + return (distance, alignments) if align else distance + + +class Logsumexp(Function): + + r""" + Safer ``logsumexp`` to cure unnecessary NaN values that arise from inf arguments. + See discussions at http://github.com/pytorch/pytorch/issues/49724. + To be optimized with C++/Cuda extensions. + """ + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float) + def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + output = x.logsumexp(dim) + ctx.dim = dim + ctx.save_for_backward(x, output) + return output.clone() + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]: + x, output, dim = *ctx.saved_tensors, ctx.dim + g, output = g.unsqueeze(dim), output.unsqueeze(dim) + mask = g.eq(0).expand_as(x) + grad = g * (x - output).exp() + return torch.where(mask, x.new_tensor(0.), grad), None + + +class Logaddexp(Function): + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float) + def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.logaddexp(x, y) + ctx.save_for_backward(x, y, output) + return output.clone() + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: + x, y, output = ctx.saved_tensors + mask = g.eq(0) + grad_x, grad_y = (x - output).exp(), (y - output).exp() + grad_x = torch.where(mask, x.new_tensor(0.), grad_x) + grad_y = torch.where(mask, y.new_tensor(0.), grad_y) + return grad_x, grad_y + + +class SampledLogsumexp(Function): + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float) + def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + ctx.dim = dim + ctx.save_for_backward(x) + return x.logsumexp(dim=dim) + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]: + from torch.distributions import OneHotCategorical + (x, ), dim = ctx.saved_tensors, ctx.dim + return g.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None + + +class Sparsemax(Function): + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float) + def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + ctx.dim = dim + sorted_x, _ = x.sort(dim, True) + z = sorted_x.cumsum(dim) - 1 + k = x.new_tensor(range(1, sorted_x.size(dim) + 1)).view(-1, *[1] * (x.dim() - 1)).transpose(0, dim) + k = (k * sorted_x).gt(z).sum(dim, True) + tau = z.gather(dim, k - 1) / k + p = torch.clamp(x - tau, 0) + ctx.save_for_backward(k, p) + return p + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, g: torch.Tensor) -> Tuple[torch.Tensor, None]: + k, p, dim = *ctx.saved_tensors, ctx.dim + grad = g.masked_fill(p.eq(0), 0) + grad = torch.where(p.ne(0), grad - grad.sum(dim, True) / k, grad) + return grad, None + + +logsumexp = Logsumexp.apply + +logaddexp = Logaddexp.apply + +sampled_logsumexp = SampledLogsumexp.apply + +sparsemax = Sparsemax.apply diff --git a/tania_scripts/supar/structs/semiring.py b/tania_scripts/supar/structs/semiring.py new file mode 100644 index 0000000000000000000000000000000000000000..9b66beee4f2e681a12cb7329e26c8880020077af --- /dev/null +++ b/tania_scripts/supar/structs/semiring.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- + +import itertools +from functools import reduce +from typing import Iterable + +import torch +from supar.structs.fn import sampled_logsumexp, sparsemax +from supar.utils.common import MIN + + +class Semiring(object): + r""" + Base semiring class :cite:`goodman-1999-semiring`. + + A semiring is defined by a tuple :math:`<K, \oplus, \otimes, \mathbf{0}, \mathbf{1}>`. + :math:`K` is a set of values; + :math:`\oplus` is commutative, associative and has an identity element `0`; + :math:`\otimes` is associative, has an identity element `1` and distributes over `+`. + """ + + zero = 0 + one = 1 + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + @classmethod + def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.sum(dim) + + @classmethod + def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.prod(dim) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.cumsum(dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.cumprod(dim) + + @classmethod + def dot(cls, x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor: + return cls.sum(cls.mul(x, y), dim) + + @classmethod + def times(cls, *x: Iterable[torch.Tensor]) -> torch.Tensor: + return reduce(lambda i, j: cls.mul(i, j), x) + + @classmethod + def zero_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.zero) + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.one) + + @classmethod + def zero_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + return x.masked_fill(mask, cls.zero) + + @classmethod + def zero_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + return x.masked_fill_(mask, cls.zero) + + @classmethod + def one_mask(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + return x.masked_fill(mask, cls.one) + + @classmethod + def one_mask_(cls, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + return x.masked_fill_(mask, cls.one) + + @classmethod + def zeros_like(cls, x: torch.Tensor) -> torch.Tensor: + return x.new_full(x.shape, cls.zero) + + @classmethod + def ones_like(cls, x: torch.Tensor) -> torch.Tensor: + return x.new_full(x.shape, cls.one) + + @classmethod + def convert(cls, x: torch.Tensor) -> torch.Tensor: + return x + + @classmethod + def unconvert(cls, x: torch.Tensor) -> torch.Tensor: + return x + + +class LogSemiring(Semiring): + r""" + Log-space semiring :math:`<\mathrm{logsumexp}, +, -\infty, 0>`. + """ + + zero = MIN + one = 0 + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.logaddexp(y) + + @classmethod + def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.logsumexp(dim) + + @classmethod + def prod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.sum(dim) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.logcumsumexp(dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.cumsum(dim) + + +class MaxSemiring(LogSemiring): + r""" + Max semiring :math:`<\mathrm{max}, +, -\infty, 0>`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.max(y) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.max(dim)[0] + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.cummax(dim) + + +def KMaxSemiring(k): + r""" + k-max semiring :math:`<\mathrm{kmax}, +, [-\infty, -\infty, \dots], [0, -\infty, \dots]>`. + """ + + class KMaxSemiring(LogSemiring): + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(-1).max(y.unsqueeze(-2)).flatten(-2).topk(k, -1)[0] + + @classmethod + def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return (x.unsqueeze(-1) + y.unsqueeze(-2)).flatten(-2).topk(k, -1)[0] + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.movedim(dim, -1).flatten(-2).topk(k, -1)[0] + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + x[..., :1].fill_(cls.one) + x[..., 1:].fill_(cls.zero) + return x + + @classmethod + def convert(cls, x: torch.Tensor) -> torch.Tensor: + return torch.cat((x.unsqueeze(-1), cls.zero_(x.new_empty(*x.shape, k - 1))), -1) + + return KMaxSemiring + + +class ExpectationSemiring(Semiring): + r""" + Expectation semiring :math:`<\oplus, +, [0, 0], [1, 0]>` :cite:`li-eisner-2009-first`. + + Practical Applications: :math:`H(p) = \log Z - \frac{1}{Z}\sum_{d \in D} p(d) r(d)`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + @classmethod + def mul(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.stack((x[..., 0] * y[..., 0], x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]), -1) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return x.sum(dim) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + @classmethod + def zero_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.zero) + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + x[..., 0].fill_(cls.one) + x[..., 1].fill_(cls.zero) + return x + + +class EntropySemiring(LogSemiring): + r""" + Entropy expectation semiring :math:`<\oplus, +, [-\infty, 0], [0, 0]>`, + where :math:`\oplus` computes the log-values and the running distributional entropy :math:`H[p]` + :cite:`li-eisner-2009-first,hwa-2000-sample,kim-etal-2019-unsupervised`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return cls.sum(torch.stack((x, y)), 0) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + p = x[..., 0].logsumexp(dim) + r = x[..., 0] - p.unsqueeze(dim) + r = r.exp().mul((x[..., 1] - r)).sum(dim) + return torch.stack((p, r), -1) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + @classmethod + def zero_(cls, x: torch.Tensor) -> torch.Tensor: + x[..., 0].fill_(cls.zero) + x[..., 1].fill_(cls.one) + return x + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.one) + + @classmethod + def convert(cls, x: torch.Tensor) -> torch.Tensor: + return torch.stack((x, cls.ones_like(x)), -1) + + @classmethod + def unconvert(cls, x: torch.Tensor) -> torch.Tensor: + return x[..., 1] + + +class CrossEntropySemiring(LogSemiring): + r""" + Cross Entropy expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`, + where :math:`\oplus` computes the log-values and the running distributional cross entropy :math:`H[p,q]` + of the two distributions :cite:`li-eisner-2009-first`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return cls.sum(torch.stack((x, y)), 0) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + p = x[..., :-1].logsumexp(dim) + r = x[..., :-1] - p.unsqueeze(dim) + r = r[..., 0].exp().mul((x[..., -1] - r[..., 1])).sum(dim) + return torch.cat((p, r.unsqueeze(-1)), -1) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + @classmethod + def zero_(cls, x: torch.Tensor) -> torch.Tensor: + x[..., :-1].fill_(cls.zero) + x[..., -1].fill_(cls.one) + return x + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.one) + + @classmethod + def convert(cls, x: torch.Tensor) -> torch.Tensor: + return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1) + + @classmethod + def unconvert(cls, x: torch.Tensor) -> torch.Tensor: + return x[..., -1] + + +class KLDivergenceSemiring(LogSemiring): + r""" + KL divergence expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`, + where :math:`\oplus` computes the log-values and the running distributional KL divergence :math:`KL[p \parallel q]` + of the two distributions :cite:`li-eisner-2009-first`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return cls.sum(torch.stack((x, y)), 0) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + p = x[..., :-1].logsumexp(dim) + r = x[..., :-1] - p.unsqueeze(dim) + r = r[..., 0].exp().mul((x[..., -1] - r[..., 1] + r[..., 0])).sum(dim) + return torch.cat((p, r.unsqueeze(-1)), -1) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + @classmethod + def zero_(cls, x: torch.Tensor) -> torch.Tensor: + x[..., :-1].fill_(cls.zero) + x[..., -1].fill_(cls.one) + return x + + @classmethod + def one_(cls, x: torch.Tensor) -> torch.Tensor: + return x.fill_(cls.one) + + @classmethod + def convert(cls, x: torch.Tensor) -> torch.Tensor: + return torch.cat((x, cls.one_(torch.empty_like(x[..., :1]))), -1) + + @classmethod + def unconvert(cls, x: torch.Tensor) -> torch.Tensor: + return x[..., -1] + + +class SampledSemiring(LogSemiring): + r""" + Sampling semiring :math:`<\mathrm{logsumexp}, +, -\infty, 0>`, + which is an exact forward-filtering, backward-sampling approach. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return cls.sum(torch.stack((x, y)), 0) + + @classmethod + def sum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return sampled_logsumexp(x, dim) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) + + +class SparsemaxSemiring(LogSemiring): + r""" + Sparsemax semiring :math:`<\mathrm{sparsemax}, +, -\infty, 0>` + :cite:`martins-etal-2016-sparsemax,mensch-etal-2018-dp,correia-etal-2020-efficient`. + """ + + @classmethod + def add(cls, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return cls.sum(torch.stack((x, y)), 0) + + @staticmethod + def sum(x: torch.Tensor, dim: int = -1) -> torch.Tensor: + p = sparsemax(x, dim) + return x.mul(p).sum(dim) - p.norm(p=2, dim=dim) + + @classmethod + def cumsum(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.add(x, y))), dim) + + @classmethod + def cumprod(cls, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + return torch.stack(list(itertools.accumulate(x.unbind(dim), lambda x, y: cls.mul(x, y))), dim) diff --git a/tania_scripts/supar/structs/tree.py b/tania_scripts/supar/structs/tree.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cc5554f9fbb1476916a31f25f3bc9821ee71d6 --- /dev/null +++ b/tania_scripts/supar/structs/tree.py @@ -0,0 +1,635 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from supar.structs.dist import StructuredDistribution +from supar.structs.fn import mst +from supar.structs.semiring import LogSemiring, Semiring +from supar.utils.fn import diagonal_stripe, expanded_stripe, stripe +from torch.distributions.utils import lazy_property + + +class MatrixTree(StructuredDistribution): + r""" + MatrixTree for calculating partitions and marginals of non-projective dependency trees in :math:`O(n^3)` + by an adaptation of Kirchhoff's MatrixTree Theorem :cite:`koo-etal-2007-structured`. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible dependent-head pairs. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking, regardless of root positions. Default: ``None``. + multiroot (bool): + If ``False``, requires the tree to contain only a single root. Default: ``True``. + + Examples: + >>> from supar import MatrixTree + >>> batch_size, seq_len = 2, 5 + >>> lens = torch.tensor([3, 4]) + >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]]) + >>> s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens) + >>> s2 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens) + >>> s1.max + tensor([0.7174, 3.7910], grad_fn=<SumBackward1>) + >>> s1.argmax + tensor([[0, 0, 1, 1, 0], + [0, 4, 1, 0, 3]]) + >>> s1.log_partition + tensor([2.0229, 6.0558], grad_fn=<CopyBackwards>) + >>> s1.log_prob(arcs) + tensor([-3.2209, -2.5756], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([1.9711, 3.4497], grad_fn=<SubBackward0>) + >>> s1.kl(s2) + tensor([1.3354, 2.6914], grad_fn=<AddBackward0>) + """ + + def __init__( + self, + scores: torch.Tensor, + lens: Optional[torch.LongTensor] = None, + multiroot: bool = False + ) -> MatrixTree: + super().__init__(scores) + + batch_size, seq_len, *_ = scores.shape + self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens + self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0) + + self.multiroot = multiroot + + def __repr__(self): + return f"{self.__class__.__name__}(multiroot={self.multiroot})" + + def __add__(self, other): + return MatrixTree(torch.stack((self.scores, other.scores)), self.lens, self.multiroot) + + @lazy_property + def max(self): + arcs = self.argmax + return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1) + + @lazy_property + def argmax(self): + with torch.no_grad(): + return mst(self.scores, self.mask, self.multiroot) + + def kmax(self, k: int) -> torch.Tensor: + # TODO: Camerini algorithm + raise NotImplementedError + + def sample(self): + raise NotImplementedError + + @lazy_property + def entropy(self): + return self.log_partition - (self.marginals * self.scores).sum((-1, -2)) + + def cross_entropy(self, other: MatrixTree) -> torch.Tensor: + return other.log_partition - (self.marginals * other.scores).sum((-1, -2)) + + def kl(self, other: MatrixTree) -> torch.Tensor: + return other.log_partition - self.log_partition + (self.marginals * (self.scores - other.scores)).sum((-1, -2)) + + def score(self, value: torch.LongTensor, partial: bool = False) -> torch.Tensor: + arcs = value + if partial: + mask, lens = self.mask, self.lens + mask = mask.index_fill(1, self.lens.new_tensor(0), 1) + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1) + arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0) + scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask)) + return self.__class__(scores, lens, **self.kwargs).log_partition + return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1) + + @torch.enable_grad() + def forward(self, semiring: Semiring) -> torch.Tensor: + s_arc = self.scores + batch_size, *_ = s_arc.shape + mask, lens = self.mask.index_fill(1, self.lens.new_tensor(0), 1), self.lens + # double precision to prevent overflows + s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2))).double() + + # A(i, j) = exp(s(i, j)) + m = s_arc.view(batch_size, -1).max(-1)[0] + A = torch.exp(s_arc - m.view(-1, 1, 1)) + + # Weighted degree matrix + # D(i, j) = sum_j(A(i, j)), if h == m + # 0, otherwise + D = torch.zeros_like(A) + D.diagonal(0, 1, 2).copy_(A.sum(-1)) + # Laplacian matrix + # L(i, j) = D(i, j) - A(i, j) + L = D - A + if not self.multiroot: + L.diagonal(0, 1, 2).add_(-A[..., 0]) + L[..., 1] = A[..., 0] + L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), L[mask]) + L = L + nn.init.eye_(torch.empty_like(A[0])) * torch.finfo().tiny + # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0 + return (L[:, 1:, 1:].logdet() + m * lens).float() + + +class DependencyCRF(StructuredDistribution): + r""" + First-order TreeCRF for projective dependency trees :cite:`eisner-2000-bilexical,zhang-etal-2020-efficient`. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all possible dependent-head pairs. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking, regardless of root positions. Default: ``None``. + multiroot (bool): + If ``False``, requires the tree to contain only a single root. Default: ``True``. + + Examples: + >>> from supar import DependencyCRF + >>> batch_size, seq_len = 2, 5 + >>> lens = torch.tensor([3, 4]) + >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]]) + >>> s1 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens) + >>> s2 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens) + >>> s1.max + tensor([3.6346, 1.7194], grad_fn=<IndexBackward>) + >>> s1.argmax + tensor([[0, 2, 3, 0, 0], + [0, 0, 3, 1, 1]]) + >>> s1.log_partition + tensor([4.1007, 3.3383], grad_fn=<IndexBackward>) + >>> s1.log_prob(arcs) + tensor([-1.3866, -5.5352], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([0.9979, 2.6056], grad_fn=<IndexBackward>) + >>> s1.kl(s2) + tensor([1.6631, 2.6558], grad_fn=<IndexBackward>) + """ + + def __init__( + self, + scores: torch.Tensor, + lens: Optional[torch.LongTensor] = None, + multiroot: bool = False + ) -> DependencyCRF: + super().__init__(scores) + + batch_size, seq_len, *_ = scores.shape + self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens + self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0) + + self.multiroot = multiroot + + def __repr__(self): + return f"{self.__class__.__name__}(multiroot={self.multiroot})" + + def __add__(self, other): + return DependencyCRF(torch.stack((self.scores, other.scores), -1), self.lens, self.multiroot) + + @lazy_property + def argmax(self): + return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2]) + + def topk(self, k: int) -> torch.LongTensor: + preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1) + return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds) + + def score(self, value: torch.Tensor, partial: bool = False) -> torch.Tensor: + arcs = value + if partial: + mask, lens = self.mask, self.lens + mask = mask.index_fill(1, self.lens.new_tensor(0), 1) + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1) + arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0) + scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask)) + return self.__class__(scores, lens, **self.kwargs).log_partition + return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1) + + def forward(self, semiring: Semiring) -> torch.Tensor: + s_arc = self.scores + batch_size, seq_len = s_arc.shape[:2] + # [seq_len, seq_len, batch_size, ...], (h->m) + s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0))) + s_i = semiring.zeros_like(s_arc) + s_c = semiring.zeros_like(s_arc) + semiring.one_(s_c.diagonal().movedim(-1, 1)) + + for w in range(1, seq_len): + n = seq_len - w + + # [n, batch_size, ...] + il = ir = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1) + # INCOMPLETE-L: I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j + # fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans + s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1)) + # INCOMPLETE-R: I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j + # fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans + s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1)) + + # [n, batch_size, ...] + # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j + cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1) + s_c.diagonal(-w).copy_(cl.movedim(0, -1)) + # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j + cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1) + s_c.diagonal(w).copy_(cr.movedim(0, -1)) + if not self.multiroot: + s_c[0, w][self.lens.ne(w)] = semiring.zero + return semiring.unconvert(s_c)[0][self.lens, range(batch_size)] + + +class Dependency2oCRF(StructuredDistribution): + r""" + Second-order TreeCRF for projective dependency trees :cite:`mcdonald-pereira-2006-online,zhang-etal-2020-efficient`. + + Args: + scores (tuple(~torch.Tensor, ~torch.Tensor)): + Scores of all possible dependent-head pairs (``[batch_size, seq_len, seq_len]``) and + dependent-head-sibling triples ``[batch_size, seq_len, seq_len, seq_len]``. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking, regardless of root positions. Default: ``None``. + multiroot (bool): + If ``False``, requires the tree to contain only a single root. Default: ``True``. + + Examples: + >>> from supar import Dependency2oCRF + >>> batch_size, seq_len = 2, 5 + >>> lens = torch.tensor([3, 4]) + >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]]) + >>> sibs = torch.tensor([CoNLL.get_sibs(i) for i in arcs[:, 1:].tolist()]) + >>> s1 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len, seq_len)), + lens) + >>> s2 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len, seq_len)), + lens) + >>> s1.max + tensor([0.7574, 3.3634], grad_fn=<IndexBackward>) + >>> s1.argmax + tensor([[0, 3, 3, 0, 0], + [0, 4, 4, 4, 0]]) + >>> s1.log_partition + tensor([1.9906, 4.3599], grad_fn=<IndexBackward>) + >>> s1.log_prob((arcs, sibs)) + tensor([-0.6975, -6.2845], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([1.6436, 2.1717], grad_fn=<IndexBackward>) + >>> s1.kl(s2) + tensor([0.4929, 2.0759], grad_fn=<IndexBackward>) + """ + + def __init__( + self, + scores: Tuple[torch.Tensor, torch.Tensor], + lens: Optional[torch.LongTensor] = None, + multiroot: bool = False + ) -> Dependency2oCRF: + super().__init__(scores) + + batch_size, seq_len, *_ = scores[0].shape + self.lens = scores[0].new_full((batch_size,), seq_len-1).long() if lens is None else lens + self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0) + + self.multiroot = multiroot + + def __repr__(self): + return f"{self.__class__.__name__}(multiroot={self.multiroot})" + + def __add__(self, other): + return Dependency2oCRF([torch.stack((i, j), -1) for i, j in zip(self.scores, other.scores)], self.lens, self.multiroot) + + @lazy_property + def argmax(self): + return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, + torch.where(self.backward(self.max.sum())[0])[2]) + + def topk(self, k: int) -> torch.LongTensor: + preds = torch.stack([torch.where(self.backward(i)[0])[2] for i in self.kmax(k).sum(0)], -1) + return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds) + + def score(self, value: Tuple[torch.LongTensor, torch.LongTensor], partial: bool = False) -> torch.Tensor: + arcs, sibs = value + if partial: + mask, lens = self.mask, self.lens + mask = mask.index_fill(1, self.lens.new_tensor(0), 1) + mask = mask.unsqueeze(1) & mask.unsqueeze(2) + arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1) + arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0) + s_arc, s_sib = LogSemiring.zero_mask(self.scores[0], ~(arcs & mask)), self.scores[1] + return self.__class__((s_arc, s_sib), lens, **self.kwargs).log_partition + s_arc = self.scores[0].gather(-1, arcs.unsqueeze(-1)).squeeze(-1) + s_arc = LogSemiring.prod(LogSemiring.one_mask(s_arc, ~self.mask), -1) + s_sib = self.scores[1].gather(-1, sibs.unsqueeze(-1)).squeeze(-1) + s_sib = LogSemiring.prod(LogSemiring.one_mask(s_sib, ~sibs.gt(0)), (-1, -2)) + return LogSemiring.mul(s_arc, s_sib) + + @torch.enable_grad() + def forward(self, semiring: Semiring) -> torch.Tensor: + s_arc, s_sib = self.scores + batch_size, seq_len = s_arc.shape[:2] + # [seq_len, seq_len, batch_size, ...], (h->m) + s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0))) + # [seq_len, seq_len, seq_len, batch_size, ...], (h->m->s) + s_sib = semiring.convert(s_sib.movedim((0, 2), (3, 0))) + s_i = semiring.zeros_like(s_arc) + s_s = semiring.zeros_like(s_arc) + s_c = semiring.zeros_like(s_arc) + semiring.one_(s_c.diagonal().movedim(-1, 1)) + + for w in range(1, seq_len): + n = seq_len - w + + # INCOMPLETE-L: I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j + # <C(j->j), C(i->j-1)> * s(j->i), otherwise + # [n, w, batch_size, ...] + il = semiring.times(stripe(s_i, n, w, (w, 1)), + stripe(s_s, n, w, (1, 0), 0), + stripe(s_sib[range(w, n+w), range(n), :], n, w, (0, 1))) + il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)), stripe(s_c, n, 1, (0, w - 1))).squeeze(1) + il = semiring.sum(il, 1) + s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1)) + # INCOMPLETE-R: I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j + # <C(i->i), C(j->i+1)> * s(i->j), otherwise + # [n, w, batch_size, ...] + ir = semiring.times(stripe(s_i, n, w), + stripe(s_s, n, w, (0, w), 0), + stripe(s_sib[range(n), range(w, n+w), :], n, w)) + if not self.multiroot: + semiring.zero_(ir[0]) + ir[:, 0] = semiring.mul(stripe(s_c, n, 1), stripe(s_c, n, 1, (w, 1))).squeeze(1) + ir = semiring.sum(ir, 1) + s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1)) + + # [batch_size, ..., n] + sl = sr = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1) + # SIB: S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j + s_s.diagonal(-w).copy_(sl) + # SIB: S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j + s_s.diagonal(w).copy_(sr) + + # [n, batch_size, ...] + # COMPLETE-L: C(j->i) = <C(r->i), I(j->r)>, i <= r < j + cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1) + s_c.diagonal(-w).copy_(cl.movedim(0, -1)) + # COMPLETE-R: C(i->j) = <I(i->r), C(r->j)>, i < r <= j + cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1) + s_c.diagonal(w).copy_(cr.movedim(0, -1)) + return semiring.unconvert(s_c)[0][self.lens, range(batch_size)] + + +class ConstituencyCRF(StructuredDistribution): + r""" + Constituency TreeCRF :cite:`zhang-etal-2020-fast,stern-etal-2017-minimal`. + + Args: + scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``. + Scores of all constituents. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking. + + Examples: + >>> from supar import ConstituencyCRF + >>> batch_size, seq_len, n_labels = 2, 5, 4 + >>> lens = torch.tensor([3, 4]) + >>> charts = torch.tensor([[[-1, 0, -1, 0, -1], + [-1, -1, 0, 0, -1], + [-1, -1, -1, 0, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]], + [[-1, 0, 0, -1, 0], + [-1, -1, 0, -1, -1], + [-1, -1, -1, 0, 0], + [-1, -1, -1, -1, 0], + [-1, -1, -1, -1, -1]]]) + >>> s1 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len, n_labels), lens, True) + >>> s2 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len, n_labels), lens, True) + >>> s1.max + tensor([3.7036, 7.2569], grad_fn=<IndexBackward0>) + >>> s1.argmax + [[[0, 1, 2], [0, 3, 0], [1, 2, 1], [1, 3, 0], [2, 3, 3]], + [[0, 1, 1], [0, 4, 2], [1, 2, 3], [1, 4, 1], [2, 3, 2], [2, 4, 3], [3, 4, 3]]] + >>> s1.log_partition + tensor([ 8.5394, 12.9940], grad_fn=<IndexBackward0>) + >>> s1.log_prob(charts) + tensor([ -8.5209, -14.1160], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([6.8868, 9.3996], grad_fn=<IndexBackward0>) + >>> s1.kl(s2) + tensor([4.0039, 4.1037], grad_fn=<IndexBackward0>) + """ + + def __init__( + self, + scores: torch.Tensor, + lens: Optional[torch.LongTensor] = None, + label: bool = False + ) -> ConstituencyCRF: + super().__init__(scores) + + batch_size, seq_len, *_ = scores.shape + self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens + self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.unsqueeze(1) & scores.new_ones(scores.shape[:3]).bool().triu_(1) + self.label = label + + def __repr__(self): + return f"{self.__class__.__name__}(label={self.label})" + + def __add__(self, other): + return ConstituencyCRF(torch.stack((self.scores, other.scores), -1), self.lens, self.label) + + @lazy_property + def argmax(self): + return [torch.nonzero(i).tolist() for i in self.backward(self.max.sum())] + + def topk(self, k: int) -> List[List[Tuple]]: + return list(zip(*[[torch.nonzero(j).tolist() for j in self.backward(i)] for i in self.kmax(k).sum(0)])) + + def score(self, value: torch.LongTensor) -> torch.Tensor: + mask = self.mask & value.ge(0) + if self.label: + scores = self.scores[mask].gather(-1, value[mask].unsqueeze(-1)).squeeze(-1) + scores = torch.full_like(mask, LogSemiring.one, dtype=scores.dtype).masked_scatter_(mask, scores) + else: + scores = LogSemiring.one_mask(self.scores, ~mask) + return LogSemiring.prod(LogSemiring.prod(scores, -1), -1) + + @torch.enable_grad() + def forward(self, semiring: Semiring) -> torch.Tensor: + batch_size, seq_len = self.scores.shape[:2] + # [seq_len, seq_len, batch_size, ...], (l->r) + scores = semiring.convert(self.scores.movedim((1, 2), (0, 1))) + scores = semiring.sum(scores, 3) if self.label else scores + s = semiring.zeros_like(scores) + s.diagonal(1).copy_(scores.diagonal(1)) + + for w in range(2, seq_len): + n = seq_len - w + # [n, batch_size, ...] + s_s = semiring.dot(stripe(s, n, w-1, (0, 1)), stripe(s, n, w-1, (1, w), False), 1) + s.diagonal(w).copy_(semiring.mul(s_s, scores.diagonal(w).movedim(-1, 0)).movedim(0, -1)) + return semiring.unconvert(s)[0][self.lens, range(batch_size)] + + +class BiLexicalizedConstituencyCRF(StructuredDistribution): + r""" + Grammarless Eisner-Satta Algorithm :cite:`eisner-satta-1999-efficient,yang-etal-2021-neural`. + + Code is revised from `Songlin Yang's implementation <https://github.com/sustcsonglin/span-based-dependency-parsing>`_. + + Args: + scores (~torch.Tensor): ``[2, batch_size, seq_len, seq_len]``. + Scores of dependencies and constituents. + lens (~torch.LongTensor): ``[batch_size]``. + Sentence lengths for masking. + + Examples: + >>> from supar import BiLexicalizedConstituencyCRF + >>> batch_size, seq_len = 2, 5 + >>> lens = torch.tensor([3, 4]) + >>> deps = torch.tensor([[0, 0, 1, 1, 0], [0, 3, 1, 0, 3]]) + >>> cons = torch.tensor([[[0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 1, 1, 1, 1], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0]]]).bool() + >>> heads = torch.tensor([[[0, 1, 1, 1, 0], + [0, 0, 2, 0, 0], + [0, 0, 0, 3, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 1, 1, 3, 3], + [0, 0, 2, 0, 0], + [0, 0, 0, 3, 0], + [0, 0, 0, 0, 4], + [0, 0, 0, 0, 0]]]) + >>> s1 = BiLexicalizedConstituencyCRF((torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len, seq_len)), + lens) + >>> s2 = BiLexicalizedConstituencyCRF((torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len), + torch.randn(batch_size, seq_len, seq_len, seq_len)), + lens) + >>> s1.max + tensor([0.5792, 2.1737], grad_fn=<MaxBackward0>) + >>> s1.argmax[0] + tensor([[0, 3, 1, 0, 0], + [0, 4, 1, 1, 0]]) + >>> s1.argmax[1] + [[[0, 3], [0, 2], [0, 1], [1, 2], [2, 3]], [[0, 4], [0, 3], [0, 2], [0, 1], [1, 2], [2, 3], [3, 4]]] + >>> s1.log_partition + tensor([1.1923, 3.2343], grad_fn=<LogsumexpBackward>) + >>> s1.log_prob((deps, cons, heads)) + tensor([-1.9123, -3.6127], grad_fn=<SubBackward0>) + >>> s1.entropy + tensor([1.3376, 2.2996], grad_fn=<SelectBackward>) + >>> s1.kl(s2) + tensor([1.0617, 2.7839], grad_fn=<SelectBackward>) + """ + + def __init__( + self, + scores: List[torch.Tensor], + lens: Optional[torch.LongTensor] = None + ) -> BiLexicalizedConstituencyCRF: + super().__init__(scores) + + batch_size, seq_len, *_ = scores[1].shape + self.lens = scores[1].new_full((batch_size,), seq_len-1).long() if lens is None else lens + self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len))) + self.mask = self.mask.unsqueeze(1) & scores[1].new_ones(scores[1].shape[:3]).bool().triu_(1) + + def __add__(self, other): + return BiLexicalizedConstituencyCRF([torch.stack((i, j), -1) for i, j in zip(self.scores, other.scores)], self.lens) + + @lazy_property + def argmax(self): + marginals = self.backward(self.max.sum()) + dep_mask = self.mask[:, 0] + dep = self.lens.new_zeros(dep_mask.shape).masked_scatter_(dep_mask, torch.where(marginals[0])[2]) + con = [torch.nonzero(i).tolist() for i in marginals[1]] + return dep, con + + def topk(self, k: int) -> Tuple[torch.LongTensor, List[List[Tuple]]]: + dep_mask = self.mask[:, 0] + marginals = [self.backward(i) for i in self.kmax(k).sum(0)] + dep_preds = torch.stack([torch.where(i)[2] for i in marginals[0]], -1) + dep_preds = self.lens.new_zeros(*dep_mask.shape, k).masked_scatter_(dep_mask.unsqueeze(-1), dep_preds) + con_preds = list(zip(*[[torch.nonzero(j).tolist() for j in i] for i in marginals[1]])) + return dep_preds, con_preds + + def score(self, value: List[Union[torch.LongTensor, torch.BoolTensor]], partial: bool = False) -> torch.Tensor: + deps, cons, heads = value + s_dep, s_con, s_head = self.scores + mask, lens = self.mask, self.lens + dep_mask, con_mask = mask[:, 0], mask + if partial: + if deps is not None: + dep_mask = dep_mask.index_fill(1, self.lens.new_tensor(0), 1) + dep_mask = dep_mask.unsqueeze(1) & dep_mask.unsqueeze(2) + deps = deps.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1) + deps = deps.eq(lens.new_tensor(range(mask.shape[1]))) | deps.lt(0) + s_dep = LogSemiring.zero_mask(s_dep, ~(deps & dep_mask)) + if cons is not None: + s_con = LogSemiring.zero_mask(s_con, ~(cons & con_mask)) + if heads is not None: + head_mask = heads.unsqueeze(-1).eq(lens.new_tensor(range(mask.shape[1]))) + head_mask = head_mask & con_mask.unsqueeze(-1) + s_head = LogSemiring.zero_mask(s_head, ~head_mask) + return self.__class__((s_dep, s_con, s_head), lens, **self.kwargs).log_partition + s_dep = LogSemiring.prod(LogSemiring.one_mask(s_dep.gather(-1, deps.unsqueeze(-1)).squeeze(-1), ~dep_mask), -1) + s_head = LogSemiring.mul(s_con, s_head.gather(-1, heads.unsqueeze(-1)).squeeze(-1)) + s_head = LogSemiring.prod(LogSemiring.prod(LogSemiring.one_mask(s_head, ~(con_mask & cons)), -1), -1) + return LogSemiring.mul(s_dep, s_head) + + def forward(self, semiring: Semiring) -> torch.Tensor: + s_dep, s_con, s_head = self.scores + batch_size, seq_len, *_ = s_con.shape + # [seq_len, seq_len, batch_size, ...], (m<-h) + s_dep = semiring.convert(s_dep.movedim(0, 2)) + s_root, s_dep = s_dep[1:, 0], s_dep[1:, 1:] + # [seq_len, seq_len, batch_size, ...], (i, j) + s_con = semiring.convert(s_con.movedim(0, 2)) + # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j, h) + s_head = semiring.mul(s_con.unsqueeze(2), semiring.convert(s_head.movedim(0, -1)[:, :, 1:])) + # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j, h) + s_span = semiring.zeros_like(s_head) + # [seq_len, seq_len, seq_len-1, batch_size, ...], (i, j<-h) + s_hook = semiring.zeros_like(s_head) + diagonal_stripe(s_span, 1).copy_(diagonal_stripe(s_head, 1)) + s_hook.diagonal(1).copy_(semiring.mul(s_dep, diagonal_stripe(s_head, 1)).movedim(0, -1)) + + for w in range(2, seq_len): + n = seq_len - w + # COMPLETE-L: s_span_l(i, j, h) = <s_span(i, k, h), s_hook(h->k, j)>, i < k < j + # [n, w, batch_size, ...] + s_l = stripe(semiring.dot(stripe(s_span, n, w-1, (0, 1)), stripe(s_hook, n, w-1, (1, w), False), 1), n, w) + # COMPLETE-R: s_span_r(i, j, h) = <s_hook(i, k<-h), s_span(k, j, h)>, i < k < j + # [n, w, batch_size, ...] + s_r = stripe(semiring.dot(stripe(s_hook, n, w-1, (0, 1)), stripe(s_span, n, w-1, (1, w), False), 1), n, w) + # COMPLETE: s_span(i, j, h) = (s_span_l(i, j, h) + s_span_r(i, j, h)) * s(i, j, h) + # [n, w, batch_size, ...] + s = semiring.mul(semiring.sum(torch.stack((s_l, s_r)), 0), diagonal_stripe(s_head, w)) + diagonal_stripe(s_span, w).copy_(s) + + if w == seq_len - 1: + continue + # ATTACH: s_hook(h->i, j) = <s(h->m), s_span(i, j, m)>, i <= m < j + # [n, seq_len, batch_size, ...] + s = semiring.dot(expanded_stripe(s_dep, n, w), diagonal_stripe(s_span, w).unsqueeze(2), 1) + s_hook.diagonal(w).copy_(s.movedim(0, -1)) + return semiring.unconvert(semiring.dot(s_span[0][self.lens, :, range(batch_size)].transpose(0, 1), s_root, 0)) diff --git a/tania_scripts/supar/structs/vi.py b/tania_scripts/supar/structs/vi.py new file mode 100644 index 0000000000000000000000000000000000000000..b848825e661ea0ecd0a258ba6c9498b8d7624a23 --- /dev/null +++ b/tania_scripts/supar/structs/vi.py @@ -0,0 +1,499 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import List, Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from supar.structs import DependencyCRF +from supar.utils.common import MIN + + +class DependencyMFVI(nn.Module): + r""" + Mean Field Variational Inference for approximately calculating marginals + of dependency trees :cite:`wang-tu-2020-second`. + """ + + def __init__(self, max_iter: int = 3) -> DependencyMFVI: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of three tensors `s_arc` and `s_sib`. + `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs. + `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.LongTensor): ``[batch_size, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.mfvi(*scores, mask) + marginals = logits.softmax(-1) + + if target is None: + return marginals + loss = F.cross_entropy(logits[mask], target[mask]) + + return loss, marginals + + def mfvi(self, s_arc, s_sib, mask): + batch_size, seq_len = mask.shape + ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0] + mask = mask.index_fill(1, ls.new_tensor(0), 1) + # [seq_len, seq_len, batch_size], (h->m) + mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + mask2o = mask.unsqueeze(1) & mask.unsqueeze(2) + mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1) + # [seq_len, seq_len, batch_size], (h->m) + s_arc = s_arc.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + s_sib = s_sib.permute(2, 1, 3, 0) * mask2o + + # posterior distributions + # [seq_len, seq_len, batch_size], (h->m) + q = s_arc + + for _ in range(self.max_iter): + q = q.softmax(0) + # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik)), k != i,j + q = s_arc + (q.unsqueeze(1) * s_sib).sum(2) + + return q.permute(2, 1, 0) + + +class DependencyLBP(nn.Module): + r""" + Loopy Belief Propagation for approximately calculating marginals + of dependency trees :cite:`smith-eisner-2008-dependency`. + """ + + def __init__(self, max_iter: int = 3) -> DependencyLBP: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of three tensors `s_arc` and `s_sib`. + `s_arc` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs. + `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.LongTensor): ``[batch_size, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.lbp(*scores, mask) + marginals = logits.softmax(-1) + + if target is None: + return marginals + loss = F.cross_entropy(logits[mask], target[mask]) + + return loss, marginals + + def lbp(self, s_arc, s_sib, mask): + batch_size, seq_len = mask.shape + ls, rs = torch.stack(torch.where(mask.new_ones(seq_len, seq_len))).view(-1, seq_len, seq_len).sort(0)[0] + mask = mask.index_fill(1, ls.new_tensor(0), 1) + # [seq_len, seq_len, batch_size], (h->m) + mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + mask2o = mask.unsqueeze(1) & mask.unsqueeze(2) + mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1) + # [seq_len, seq_len, batch_size], (h->m) + s_arc = s_arc.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + s_sib = s_sib.permute(2, 1, 3, 0).masked_fill_(~mask2o, MIN) + + # log beliefs + # [seq_len, seq_len, batch_size], (h->m) + q = s_arc + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + m_sib = s_sib.new_zeros(seq_len, seq_len, seq_len, batch_size) + + for _ in range(self.max_iter): + q = q.log_softmax(0) + # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik)) + m = q.unsqueeze(2) - m_sib + # TODO: better solution for OOM + m_sib = torch.logaddexp(m.logsumexp(0), m + s_sib).transpose(1, 2).log_softmax(0) + # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j + q = s_arc + (m_sib * mask2o).sum(2) + + return q.permute(2, 1, 0) + + +class ConstituencyMFVI(nn.Module): + r""" + Mean Field Variational Inference for approximately calculating marginals of constituent trees. + """ + + def __init__(self, max_iter: int = 3) -> ConstituencyMFVI: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of two tensors `s_span` and `s_pair`. + `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans. + `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.mfvi(*scores, mask) + marginals = logits.sigmoid() + + if target is None: + return marginals + loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float()) + + return loss, marginals + + def mfvi(self, s_span, s_pair, mask): + batch_size, seq_len, _ = mask.shape + ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0] + # [seq_len, seq_len, batch_size], (l->r) + mask = mask.movedim(0, 2) + # [seq_len, seq_len, seq_len, batch_size], (l->r->b) + mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1) + mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1) + # [seq_len, seq_len, batch_size], (l->r) + s_span = s_span.movedim(0, 2) + # [seq_len, seq_len, seq_len, batch_size], (l->r->b) + s_pair = s_pair.permute(1, 2, 3, 0) * mask2o + + # posterior distributions + # [seq_len, seq_len, batch_size], (l->r) + q = s_span + + for _ in range(self.max_iter): + q = q.sigmoid() + # q(ij) = s(ij) + sum(q(jk)*s^pair(ij,jk), k != i,j + q = s_span + (q.unsqueeze(1) * s_pair).sum(2) + + return q.permute(2, 0, 1) + + +class ConstituencyLBP(nn.Module): + r""" + Loopy Belief Propagation for approximately calculating marginals of constituent trees. + """ + + def __init__(self, max_iter: int = 3) -> ConstituencyLBP: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`. + `s_span` (``[batch_size, seq_len, seq_len]``) holds scores of all possible spans. + `s_pair` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of second-order triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.lbp(*scores, mask) + marginals = logits.softmax(-1)[..., 1] + + if target is None: + return marginals + loss = F.cross_entropy(logits[mask], target[mask].long()) + + return loss, marginals + + def lbp(self, s_span, s_pair, mask): + batch_size, seq_len, _ = mask.shape + ls, rs = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len).sort(0)[0] + # [seq_len, seq_len, batch_size], (l->r) + mask = mask.movedim(0, 2) + # [seq_len, seq_len, seq_len, batch_size], (l->r->b) + mask2o = mask.unsqueeze(2).repeat(1, 1, seq_len, 1) + mask2o = mask2o & ls.unsqueeze(-1).ne(ls.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & rs.unsqueeze(-1).ne(rs.new_tensor(range(seq_len))).unsqueeze(-1) + # [2, seq_len, seq_len, batch_size], (l->r) + s_span = torch.stack((torch.zeros_like(s_span), s_span)).permute(0, 3, 2, 1) + # [seq_len, seq_len, seq_len, batch_size], (l->r->p) + s_pair = s_pair.permute(2, 1, 3, 0) + + # log beliefs + # [2, seq_len, seq_len, batch_size], (h->m) + q = s_span + # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s) + m_pair = s_pair.new_zeros(2, seq_len, seq_len, seq_len, batch_size) + + for _ in range(self.max_iter): + q = q.log_softmax(0) + # m(ik->ij) = logsumexp(q(ik) - m(ij->ik) + s(ij->ik)) + m = q.unsqueeze(3) - m_pair + m_pair = torch.stack((m.logsumexp(0), torch.stack((m[0], m[1] + s_pair)).logsumexp(0))).log_softmax(0) + # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j + q = s_span + (m_pair.transpose(2, 3) * mask2o).sum(3) + + return q.permute(3, 2, 1, 0) + + +class SemanticDependencyMFVI(nn.Module): + r""" + Mean Field Variational Inference for approximately calculating marginals + of semantic dependency trees :cite:`wang-etal-2019-second`. + """ + + def __init__(self, max_iter: int = 3) -> SemanticDependencyMFVI: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`. + `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs. + `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples. + `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples. + `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.mfvi(*scores, mask) + marginals = logits.sigmoid() + + if target is None: + return marginals + loss = F.binary_cross_entropy_with_logits(logits[mask], target[mask].float()) + + return loss, marginals + + def mfvi(self, s_edge, s_sib, s_cop, s_grd, mask): + _, seq_len, _ = mask.shape + hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len) + # [seq_len, seq_len, batch_size], (h->m) + mask = mask.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + mask2o = mask.unsqueeze(1) & mask.unsqueeze(2) + mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o.diagonal().fill_(0) + # [seq_len, seq_len, batch_size], (h->m) + s_edge = s_edge.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + s_sib = s_sib.permute(2, 1, 3, 0) * mask2o + # [seq_len, seq_len, seq_len, batch_size], (h->m->c) + s_cop = s_cop.permute(2, 1, 3, 0) * mask2o + # [seq_len, seq_len, seq_len, batch_size], (h->m->g) + s_grd = s_grd.permute(2, 1, 3, 0) * mask2o + + # posterior distributions + # [seq_len, seq_len, batch_size], (h->m) + q = s_edge + + for _ in range(self.max_iter): + q = q.sigmoid() + # q(ij) = s(ij) + sum(q(ik)s^sib(ij,ik) + q(kj)s^cop(ij,kj) + q(jk)s^grd(ij,jk)), k != i,j + q = s_edge + (q.unsqueeze(1) * s_sib + q.transpose(0, 1).unsqueeze(0) * s_cop + q.unsqueeze(0) * s_grd).sum(2) + + return q.permute(2, 1, 0) + + +class SemanticDependencyLBP(nn.Module): + r""" + Loopy Belief Propagation for approximately calculating marginals + of semantic dependency trees :cite:`wang-etal-2019-second`. + """ + + def __init__(self, max_iter: int = 3) -> SemanticDependencyLBP: + super().__init__() + + self.max_iter = max_iter + + def __repr__(self): + return f"{self.__class__.__name__}(max_iter={self.max_iter})" + + @torch.enable_grad() + def forward( + self, + scores: List[torch.Tensor], + mask: torch.BoolTensor, + target: Optional[torch.LongTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + scores (~torch.Tensor, ~torch.Tensor): + Tuple of four tensors `s_edge`, `s_sib`, `s_cop` and `s_grd`. + `s_edge` (``[batch_size, seq_len, seq_len]``) holds scores of all possible dependent-head pairs. + `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples. + `s_cop` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-coparent triples. + `s_grd` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-grandparent triples. + mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``. + The mask to avoid aggregation on padding tokens. + target (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``. + A Tensor of gold-standard dependent-head pairs. Default: ``None``. + + Returns: + ~torch.Tensor, ~torch.Tensor: + The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``. + The second is a tensor for marginals of shape ``[batch_size, seq_len, seq_len]``. + """ + + logits = self.lbp(*scores, mask) + marginals = logits.softmax(-1)[..., 1] + + if target is None: + return marginals + loss = F.cross_entropy(logits[mask], target[mask]) + + return loss, marginals + + def lbp(self, s_edge, s_sib, s_cop, s_grd, mask): + lens = mask[..., 0].sum(1) + _, seq_len, _ = mask.shape + hs, ms = torch.stack(torch.where(torch.ones_like(mask[0]))).view(-1, seq_len, seq_len) + # [seq_len, seq_len, batch_size], (h->m) + mask = mask.permute(2, 1, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + mask2o = mask.unsqueeze(1) & mask.unsqueeze(2) + mask2o = mask2o & hs.unsqueeze(-1).ne(hs.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o = mask2o & ms.unsqueeze(-1).ne(ms.new_tensor(range(seq_len))).unsqueeze(-1) + mask2o.diagonal().fill_(0) + # [2, seq_len, seq_len, batch_size], (h->m) + s_edge = torch.stack((torch.zeros_like(s_edge), s_edge)).permute(0, 3, 2, 1) + # [seq_len, seq_len, seq_len, batch_size], (h->m->s) + s_sib = s_sib.permute(2, 1, 3, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->c) + s_cop = s_cop.permute(2, 1, 3, 0) + # [seq_len, seq_len, seq_len, batch_size], (h->m->g) + s_grd = s_grd.permute(2, 1, 3, 0) + + # log beliefs + # [2, seq_len, seq_len, batch_size], (h->m) + q = s_edge + # sibling factor + # [2, seq_len, seq_len, seq_len, batch_size], (h->m->s) + m_sib = s_sib.new_zeros(2, *mask2o.shape) + # coparent factor + # [2, seq_len, seq_len, seq_len, batch_size], (h->m->c) + m_cop = s_cop.new_zeros(2, *mask2o.shape) + # grandparent factor + # [2, seq_len, seq_len, seq_len, batch_size], (h->m->g) + m_grd = s_grd.new_zeros(2, *mask2o.shape) + # tree factor + # [2, seq_len, seq_len, batch_size], (h->m) + m_tree = torch.zeros_like(s_edge) + + for _ in range(self.max_iter): + # sibling factor + v_sib = q.unsqueeze(2) - m_sib + m_sib = torch.stack((v_sib.logsumexp(0), torch.stack((v_sib[0], v_sib[1] + s_sib)).logsumexp(0))).log_softmax(0) + # coparent factor + v_cop = q.transpose(1, 2).unsqueeze(1) - m_cop + m_cop = torch.stack((v_cop.logsumexp(0), torch.stack((v_cop[0], v_cop[1] + s_cop)).logsumexp(0))).log_softmax(0) + # grandparent factor + v_grd = q.unsqueeze(1) - m_grd + m_grd = torch.stack((v_grd.logsumexp(0), torch.stack((v_grd[0], v_grd[1] + s_grd)).logsumexp(0))).log_softmax(0) + # tree factor + v_tree = q - m_tree + b_tree = DependencyCRF((v_tree[1] - v_tree[0]).permute(2, 1, 0), lens).marginals.permute(2, 1, 0) + b_tree = torch.stack((1 - b_tree, b_tree)) + m_tree = (b_tree.clamp(torch.finfo().eps).log() - v_tree).log_softmax(0) + # q(ij) = s(ij) + sum(m(ik->ij)), k != i,j + q = s_edge + ((m_sib + m_cop + m_grd).transpose(2, 3) * mask2o).sum(3) + m_tree + + return q.permute(3, 2, 1, 0) diff --git a/tania_scripts/supar/utils/__init__.py b/tania_scripts/supar/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..279bb3fee548100e6538d245e4adf226c6a41a67 --- /dev/null +++ b/tania_scripts/supar/utils/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +from . import field, fn, metric, transform +from .config import Config +from .data import Dataset +from .embed import Embedding +from .field import ChartField, Field, RawField, SubwordField +from .transform import Transform +from .vocab import Vocab + +__all__ = ['Config', + 'Dataset', + 'Embedding', + 'RawField', 'Field', 'SubwordField', 'ChartField', + 'Transform', + 'Vocab', + 'field', 'fn', 'metric', 'transform'] diff --git a/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7085981158edde053f2e613913ae5104f375bc5 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea9d342a076da2ffb2ad39c0137168c37c8ba0a9 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d3aa81c201d0bcf97afeb8657a7fbfda08ae766 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/common.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a95bf475fc0c1b5ed5fbe397eed664cc93daf29 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/common.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca40a0457ad92a8c30fc0f85221363889e017cd Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/config.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c702557431115d22e5cd81b51b3e043b2e3df185 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/config.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86c2665f120944e63bad5a419675ff4d4c8cdfdc Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/data.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e9fad9bc068980a106abf82b540436629543ca Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/data.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06e59dca4554b13420c432192952564973ebd013 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/embed.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a582efc082811b6496eeb81a30e62f80b7f484 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/embed.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67521b6cccf3e5ff86e13752f33222767dc10c66 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/field.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..448dcecae208170106c7fee557c7e348e4ae066f Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/field.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d4582deab9dc73c2dbc769f06c95e6ab6002084 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/fn.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b894334d6aceb12d3be2fb5c1afb8363709ba8d3 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/fn.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d47f1e45e93b38ad9f76e9acb3bef318cf82085f Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/logging.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f52d49b3cbff680cb6322a30b5b1986941d883a6 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/logging.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb49a374e350bfdb86f44ddc9bb384ce41fe121 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/metric.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d685d90d287c335bc5300af1b9cb989c3df9b65 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/metric.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf0d056ae56044d19b2e5e290f45058425bc2af3 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/optim.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f34b83e7bcd94eda11b8b415630ab6d562cf1c79 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/optim.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d12400edeb441199d157172e73705a9fc94003fb Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/parallel.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c67aea66c324dbdeb02c9220e11cd0318b49f8 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/parallel.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb268708f3777028bfa8541824a2616c07fc8b52 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37128ac854fdfde72e08c626b1a6cda83b86d673 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa125d981767ec90dad382896177289419f0d04a Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/transform.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd491045274e55fd2e1572e2fcbdf56acaec5e61 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/transform.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc b/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89fc3d6b458d28f39f20492fa93b0fd4981e51fd Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/vocab.cpython-310.pyc differ diff --git a/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc b/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37a8843712eaa44afd7d45ac81dfeef4e4038e50 Binary files /dev/null and b/tania_scripts/supar/utils/__pycache__/vocab.cpython-311.pyc differ diff --git a/tania_scripts/supar/utils/common.py b/tania_scripts/supar/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f320d2909e3bb86a83721d6ce820545295ccca3d --- /dev/null +++ b/tania_scripts/supar/utils/common.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +import os + +PAD = '<pad>' +UNK = '<unk>' +BOS = '<bos>' +EOS = '<eos>' +NUL = '<nul>' + +MIN = -1e32 +INF = float('inf') + +CACHE = os.path.expanduser('~/.cache/supar') diff --git a/tania_scripts/supar/utils/config.py b/tania_scripts/supar/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..28b18adcab745b57c6a1000090510f82b986062c --- /dev/null +++ b/tania_scripts/supar/utils/config.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import yaml +import os +from ast import literal_eval +from configparser import ConfigParser +from typing import Any, Dict, Optional, Sequence + +import supar +from omegaconf import OmegaConf +from supar.utils.fn import download + + +class Config(object): + + def __init__(self, **kwargs: Any) -> None: + super(Config, self).__init__() + + self.update(kwargs) + + def __repr__(self) -> str: + return yaml.dump(self.__dict__) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def __getstate__(self) -> Dict[str, Any]: + return self.__dict__ + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + + @property + def primitive_config(self) -> Dict[str, Any]: + from enum import Enum + from pathlib import Path + primitive_types = (int, float, bool, str, bytes, Enum, Path) + return {name: value for name, value in self.__dict__.items() if type(value) in primitive_types} + + def keys(self) -> Any: + return self.__dict__.keys() + + def items(self) -> Any: + return self.__dict__.items() + + def update(self, kwargs: Dict[str, Any]) -> Config: + for key in ('self', 'cls', '__class__'): + kwargs.pop(key, None) + kwargs.update(kwargs.pop('kwargs', dict())) + for name, value in kwargs.items(): + setattr(self, name, value) + return self + + def get(self, key: str, default: Optional[Any] = None) -> Any: + return getattr(self, key, default) + + def pop(self, key: str, default: Optional[Any] = None) -> Any: + return self.__dict__.pop(key, default) + + def save(self, path): + with open(path, 'w') as f: + f.write(str(self)) + + @classmethod + def load(cls, conf: str = '', unknown: Optional[Sequence[str]] = None, **kwargs: Any) -> Config: + if conf and not os.path.exists(conf): + conf = download(supar.CONFIG['github'].get(conf, conf)) + if conf.endswith(('.yml', '.yaml')): + config = OmegaConf.load(conf) + else: + config = ConfigParser() + config.read(conf) + config = dict((name, literal_eval(value)) for s in config.sections() for name, value in config.items(s)) + if unknown is not None: + parser = argparse.ArgumentParser() + for name, value in config.items(): + parser.add_argument('--'+name.replace('_', '-'), type=type(value), default=value) + config.update(vars(parser.parse_args(unknown))) + return cls(**config).update(kwargs) diff --git a/tania_scripts/supar/utils/data.py b/tania_scripts/supar/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..587b32d60f78b39cc2e6c425fe69c54f0919e61d --- /dev/null +++ b/tania_scripts/supar/utils/data.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import itertools +import os +import queue +import shutil +import tempfile +import threading +from contextlib import contextmanager +from typing import Dict, Iterable, List, Union + +import pathos.multiprocessing as mp +import torch +import torch.distributed as dist +from supar.utils.common import INF +from supar.utils.fn import binarize, debinarize, kmeans +from supar.utils.logging import get_logger, progress_bar +from supar.utils.parallel import is_dist, is_master +from supar.utils.transform import Batch, Transform +from torch.distributions.utils import lazy_property + +logger = get_logger(__name__) + + +class Dataset(torch.utils.data.Dataset): + r""" + Dataset that is compatible with :class:`torch.utils.data.Dataset`, serving as a wrapper for manipulating all data fields + with the operating behaviours defined in :class:`~supar.utils.transform.Transform`. + The data fields of all the instantiated sentences can be accessed as an attribute of the dataset. + + Args: + transform (Transform): + An instance of :class:`~supar.utils.transform.Transform` or its derivations. + The instance holds a series of loading and processing behaviours with regard to the specific data format. + data (Union[str, Iterable]): + A filename or a list of instances that will be passed into :meth:`transform.load`. + cache (bool): + If ``True``, tries to use the previously cached binarized data for fast loading. + In this way, sentences are loaded on-the-fly according to the meta data. + If ``False``, all sentences will be directly loaded into the memory. + Default: ``False``. + binarize (bool): + If ``True``, binarizes the dataset once building it. Only works if ``cache=True``. Default: ``False``. + bin (str): + Path for saving binarized files, required if ``cache=True``. Default: ``None``. + max_len (int): + Sentences exceeding the length will be discarded. Default: ``None``. + kwargs (Dict): + Together with `data`, kwargs will be passed into :meth:`transform.load` to control the loading behaviour. + + Attributes: + transform (Transform): + An instance of :class:`~supar.utils.transform.Transform`. + sentences (List[Sentence]): + A list of sentences loaded from the data. + Each sentence includes fields obeying the data format defined in ``transform``. + If ``cache=True``, each is a pointer to the sentence stored in the cache file. + """ + + def __init__( + self, + transform: Transform, + data: Union[str, Iterable], + cache: bool = False, + binarize: bool = False, + bin: str = None, + max_len: int = None, + **kwargs + ) -> Dataset: + super(Dataset, self).__init__() + + self.transform = transform + self.data = data + self.cache = cache + self.binarize = binarize + self.bin = bin + self.max_len = max_len or INF + self.kwargs = kwargs + + if cache: + if not isinstance(data, str) or not os.path.exists(data): + raise FileNotFoundError("Only files are allowed for binarization, but not found") + if self.bin is None: + self.fbin = data + '.pt' + else: + os.makedirs(self.bin, exist_ok=True) + self.fbin = os.path.join(self.bin, os.path.split(data)[1]) + '.pt' + if not self.binarize and os.path.exists(self.fbin): + try: + self.sentences = debinarize(self.fbin, meta=True)['sentences'] + except Exception: + raise RuntimeError(f"Error found while debinarizing {self.fbin}, which may have been corrupted. " + "Try re-binarizing it first!") + else: + self.sentences = list(transform.load(data, **kwargs)) + + + def __repr__(self): + s = f"{self.__class__.__name__}(" + s += f"n_sentences={len(self.sentences)}" + if hasattr(self, 'loader'): + s += f", n_batches={len(self.loader)}" + if hasattr(self, 'buckets'): + s += f", n_buckets={len(self.buckets)}" + if self.cache: + s += f", cache={self.cache}" + if self.binarize: + s += f", binarize={self.binarize}" + if self.max_len < INF: + s += f", max_len={self.max_len}" + s += ")" + return s + + def __len__(self): + return len(self.sentences) + + def __getitem__(self, index): + return debinarize(self.fbin, self.sentences[index]) if self.cache else self.sentences[index] + + def __getattr__(self, name): + if name not in {f.name for f in self.transform.flattened_fields}: + raise AttributeError + if self.cache: + if os.path.exists(self.fbin) and not self.binarize: + sentences = self + else: + sentences = self.transform.load(self.data, **self.kwargs) + return (getattr(sentence, name) for sentence in sentences) + return [getattr(sentence, name) for sentence in self.sentences] + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + @lazy_property + def sizes(self): + if not self.cache: + return [s.size for s in self.sentences] + return debinarize(self.fbin, 'sizes') + + def build( + self, + batch_size: int, + n_buckets: int = 1, + shuffle: bool = False, + distributed: bool = False, + even: bool = True, + n_workers: int = 0, + pin_memory: bool = True, + chunk_size: int = 1000, + ) -> Dataset: + # numericalize all fields + if not self.cache: + self.sentences = [i for i in self.transform(self.sentences) if len(i) < self.max_len] + else: + # if not forced to do binarization and the binarized file already exists, directly load the meta file + if os.path.exists(self.fbin) and not self.binarize: + self.sentences = debinarize(self.fbin, meta=True)['sentences'] + else: + @contextmanager + def cache(sentences): + ftemp = tempfile.mkdtemp() + fs = os.path.join(ftemp, 'sentences') + fb = os.path.join(ftemp, os.path.basename(self.fbin)) + global global_transform + global_transform = self.transform + sentences = binarize({'sentences': progress_bar(sentences)}, fs)[1]['sentences'] + try: + yield ((sentences[s:s+chunk_size], fs, f"{fb}.{i}", self.max_len) + for i, s in enumerate(range(0, len(sentences), chunk_size))) + finally: + del global_transform + shutil.rmtree(ftemp) + + def numericalize(sentences, fs, fb, max_len): + sentences = global_transform((debinarize(fs, sentence) for sentence in sentences)) + sentences = [i for i in sentences if len(i) < max_len] + return binarize({'sentences': sentences, 'sizes': [sentence.size for sentence in sentences]}, fb)[0] + + logger.info(f"Seeking to cache the data to {self.fbin} first") + # numericalize the fields of each sentence + if is_master(): + with cache(self.transform.load(self.data, **self.kwargs)) as chunks, mp.Pool(32) as pool: + results = [pool.apply_async(numericalize, chunk) for chunk in chunks] + self.sentences = binarize((r.get() for r in results), self.fbin, merge=True)[1]['sentences'] + if is_dist(): + dist.barrier() + if not is_master(): + self.sentences = debinarize(self.fbin, meta=True)['sentences'] + # NOTE: the final bucket count is roughly equal to n_buckets + self.buckets = dict(zip(*kmeans(self.sizes, n_buckets))) + self.loader = DataLoader(transform=self.transform, + dataset=self, + batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed, even), + num_workers=n_workers, + collate_fn=collate_fn, + pin_memory=pin_memory) + return self + + +class Sampler(torch.utils.data.Sampler): + r""" + Sampler that supports for bucketization and token-level batchification. + + Args: + buckets (Dict): + A dict that maps each centroid to indices of clustered sentences. + The centroid corresponds to the average length of all sentences in the bucket. + batch_size (int): + Token-level batch size. The resulting batch contains roughly the same number of tokens as ``batch_size``. + shuffle (bool): + If ``True``, the sampler will shuffle both buckets and samples in each bucket. Default: ``False``. + distributed (bool): + If ``True``, the sampler will be used in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` + that restricts data loading to a subset of the dataset. + Default: ``False``. + even (bool): + If ``True``, the sampler will add extra indices to make the data evenly divisible across the replicas. + Default: ``True``. + """ + + def __init__( + self, + buckets: Dict[float, List], + batch_size: int, + shuffle: bool = False, + distributed: bool = False, + even: bool = True + ) -> Sampler: + self.batch_size = batch_size + self.shuffle = shuffle + self.distributed = distributed + self.even = even + #print("237 utils data buckets items", buckets.items()) + self.sizes, self.buckets = zip(*[(size, bucket) for size, bucket in buckets.items()]) + # number of batches in each bucket, clipped by range [1, len(bucket)] + self.n_batches = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) + for size, bucket in zip(self.sizes, self.buckets)] + self.rank, self.n_replicas, self.n_samples = 0, 1, self.n_total_samples + + if distributed: + self.rank = dist.get_rank() + self.n_replicas = dist.get_world_size() + self.n_samples = self.n_total_samples // self.n_replicas + if self.n_total_samples % self.n_replicas != 0: + self.n_samples += 1 if even else int(self.rank < self.n_total_samples % self.n_replicas) + self.epoch = 1 + + def __iter__(self): + g = torch.Generator() + g.manual_seed(self.epoch) + self.epoch += 1 + + total, batches = 0, [] + # if `shuffle=True`, shuffle both the buckets and samples in each bucket + # for distributed training, make sure each process generates the same random sequence at each epoch + range_fn = torch.arange if not self.shuffle else lambda x: torch.randperm(x, generator=g) + for i in itertools.cycle(range(len(self.buckets))): + bucket = self.buckets[i] + split_sizes = [(len(bucket) - j - 1) // self.n_batches[i] + 1 for j in range(self.n_batches[i])] + # DON'T use `torch.chunk` which may return wrong number of batches + for batch in range_fn(len(bucket)).split(split_sizes): + #print('supar utils 270', batch) + if total % self.n_replicas == self.rank: + batches.append([bucket[j] for j in batch.tolist()]) + if len(batches) == self.n_samples: + return iter(batches[i] for i in range_fn(self.n_samples).tolist()) + total += 1 + + def __len__(self): + return self.n_samples + + @property + def n_total_samples(self): + return sum(self.n_batches) + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class DataLoader(torch.utils.data.DataLoader): + + r""" + A wrapper for native :class:`torch.utils.data.DataLoader` enhanced with a data prefetcher. + See http://stackoverflow.com/questions/7323664/python-generator-pre-fetch and + https://github.com/NVIDIA/apex/issues/304. + """ + + def __init__(self, transform, **kwargs): + super().__init__(**kwargs) + + self.transform = transform + + def __iter__(self): + return PrefetchGenerator(self.transform, super().__iter__()) + + +class PrefetchGenerator(threading.Thread): + + def __init__(self, transform, loader, prefetch=1): + threading.Thread.__init__(self) + + self.transform = transform + + self.queue = queue.Queue(prefetch) + self.loader = loader + self.daemon = True + if torch.cuda.is_available(): + self.stream = torch.cuda.Stream() + + self.start() + + def __iter__(self): + return self + + def __next__(self): + if hasattr(self, 'stream'): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.queue.get() + if batch is None: + raise StopIteration + return batch + + def run(self): + # `torch.cuda.current_device` is thread local + # see https://github.com/pytorch/pytorch/issues/56588 + if is_dist() and torch.cuda.is_available(): + torch.cuda.set_device(dist.get_rank()) + if hasattr(self, 'stream'): + with torch.cuda.stream(self.stream): + for batch in self.loader: + self.queue.put(batch.compose(self.transform)) + else: + for batch in self.loader: + self.queue.put(batch.compose(self.transform)) + self.queue.put(None) + + +def collate_fn(x): + return Batch(x) diff --git a/tania_scripts/supar/utils/embed.py b/tania_scripts/supar/utils/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..c132b4da9627425576d339bf57cfdb15ed94aa89 --- /dev/null +++ b/tania_scripts/supar/utils/embed.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import Iterable, Optional, Union + +import torch +from supar.utils.common import CACHE +from supar.utils.fn import download +from supar.utils.logging import progress_bar +from torch.distributions.utils import lazy_property + + +class Embedding(object): + r""" + Defines a container object for holding pretrained embeddings. + This object is callable and behaves like :class:`torch.nn.Embedding`. + For huge files, this object supports lazy loading, seeking to retrieve vectors from the disk on the fly if necessary. + + Currently available embeddings: + - `GloVe`_ + - `Fasttext`_ + - `Giga`_ + - `Tencent`_ + + Args: + path (str): + Path to the embedding file or short name registered in ``supar.utils.embed.PRETRAINED``. + unk (Optional[str]): + The string token used to represent OOV tokens. Default: ``None``. + skip_first (bool) + If ``True``, skips the first line of the embedding file. Default: ``False``. + cache (bool): + If ``True``, instead of loading entire embeddings into memory, seeks to load vectors from the disk once called. + Default: ``True``. + sep (str): + Separator used by embedding file. Default: ``' '``. + + Examples: + >>> import torch.nn as nn + >>> from supar.utils.embed import Embedding + >>> glove = Embedding.load('glove-6b-100') + >>> glove + GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True) + >>> fasttext = Embedding.load('fasttext-en') + >>> fasttext + FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True) + >>> giga = Embedding.load('giga-100') + >>> giga + GigaEmbedding(n_tokens=372846, dim=100, cache=True) + >>> indices = torch.tensor([glove.vocab[i.lower()] for i in ['She', 'enjoys', 'playing', 'tennis', '.']]) + >>> indices + tensor([ 67, 8371, 697, 2140, 2]) + >>> glove(indices).shape + torch.Size([5, 100]) + >>> glove(indices).equal(nn.Embedding.from_pretrained(glove.vectors)(indices)) + True + + .. _GloVe: + https://nlp.stanford.edu/projects/glove/ + .. _Fasttext: + https://fasttext.cc/docs/en/crawl-vectors.html + .. _Giga: + https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip + .. _Tencent: + https://ai.tencent.com/ailab/nlp/zh/download.html + """ + + CACHE = os.path.join(CACHE, 'data/embeds') + + def __init__( + self, + path: str, + unk: Optional[str] = None, + skip_first: bool = False, + cache: bool = True, + sep: str = ' ', + **kwargs + ) -> Embedding: + super().__init__() + + self.path = path + self.unk = unk + self.skip_first = skip_first + self.cache = cache + self.sep = sep + self.kwargs = kwargs + + self.vocab = {token: i for i, token in enumerate(self.tokens)} + + def __len__(self): + return len(self.vocab) + + def __repr__(self): + s = f"{self.__class__.__name__}(" + s += f"n_tokens={len(self)}, dim={self.dim}" + if self.unk is not None: + s += f", unk={self.unk}" + if self.skip_first: + s += f", skip_first={self.skip_first}" + if self.cache: + s += f", cache={self.cache}" + s += ")" + return s + + def __contains__(self, token): + return token in self.vocab + + def __getitem__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor: + indices = key + if not isinstance(indices, torch.Tensor): + indices = torch.tensor(key) + if self.cache: + elems, indices = indices.unique(return_inverse=True) + with open(self.path) as f: + vectors = [] + for index in elems.tolist(): + f.seek(self.positions[index]) + vectors.append(list(map(float, f.readline().strip().split(self.sep)[1:]))) + vectors = torch.tensor(vectors) + else: + vectors = self.vectors + return torch.embedding(vectors, indices) + + def __call__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor: + return self[key] + + @lazy_property + def dim(self): + return len(self[0]) + + @lazy_property + def unk_index(self): + if self.unk is not None: + return self.vocab[self.unk] + raise AttributeError + + @lazy_property + def tokens(self): + with open(self.path) as f: + if self.skip_first: + f.readline() + return [line.strip().split(self.sep)[0] for line in progress_bar(f)] + + @lazy_property + def vectors(self): + with open(self.path) as f: + if self.skip_first: + f.readline() + return torch.tensor([list(map(float, line.strip().split(self.sep)[1:])) for line in progress_bar(f)]) + + @lazy_property + def positions(self): + with open(self.path) as f: + if self.skip_first: + f.readline() + positions = [f.tell()] + while True: + line = f.readline() + if line: + positions.append(f.tell()) + else: + break + return positions + + @classmethod + def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding: + if path in PRETRAINED: + cfg = dict(**PRETRAINED[path]) + embed = cfg.pop('_target_') + return embed(**cfg, **kwargs) + return cls(path, unk, **kwargs) + + +class GloVeEmbedding(Embedding): + + r""" + `GloVe`_: Global Vectors for Word Representation. + Training is performed on aggregated global word-word co-occurrence statistics from a corpus, + and the resulting representations showcase interesting linear substructures of the word vector space. + + Args: + src (str): + Size of the source data for training. Default: ``6B``. + dim (int): + Which dimension of the embeddings to use. Default: 100. + reload (bool): + If ``True``, forces a fresh download. Default: ``False``. + + Examples: + >>> from supar.utils.embed import Embedding + >>> Embedding.load('glove-6b-100') + GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True) + + .. _GloVe: + https://nlp.stanford.edu/projects/glove/ + """ + + def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloVeEmbedding: + if src == '6B' or src == 'twitter.27B': + url = f'https://nlp.stanford.edu/data/glove.{src}.zip' + else: + url = f'https://nlp.stanford.edu/data/glove.{src}.{dim}d.zip' + path = os.path.join(os.path.join(self.CACHE, 'glove'), f'glove.{src}.{dim}d.txt') + if not os.path.exists(path) or reload: + download(url, os.path.join(self.CACHE, 'glove'), clean=True) + + super().__init__(path=path, unk='unk', *args, **kwargs, ) + + +class FasttextEmbedding(Embedding): + + r""" + `Fasttext`_ word embeddings for 157 languages, trained using CBOW, in dimension 300, + with character n-grams of length 5, a window of size 5 and 10 negatives. + + Args: + lang (str): + Language code. Default: ``en``. + reload (bool): + If ``True``, forces a fresh download. Default: ``False``. + + Examples: + >>> from supar.utils.embed import Embedding + >>> Embedding.load('fasttext-en') + FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True) + + .. _Fasttext: + https://fasttext.cc/docs/en/crawl-vectors.html + """ + + def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding: + url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz' + path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec') + if not os.path.exists(path) or reload: + download(url, os.path.join(self.CACHE, 'fasttext'), clean=True) + + super().__init__(path=path, skip_first=True, *args, **kwargs) + + +class GigaEmbedding(Embedding): + + r""" + `Giga`_ word embeddings, trained on Chinese Gigaword Third Edition for Chinese using word2vec, + used by :cite:`zhang-etal-2020-efficient` and :cite:`zhang-etal-2020-fast`. + + Args: + reload (bool): + If ``True``, forces a fresh download. Default: ``False``. + + Examples: + >>> from supar.utils.embed import Embedding + >>> Embedding.load('giga-100') + GigaEmbedding(n_tokens=372846, dim=100, cache=True) + + .. _Giga: + https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip + """ + + def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding: + url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip' + path = os.path.join(self.CACHE, 'giga', 'giga.100.txt') + if not os.path.exists(path) or reload: + download(url, os.path.join(self.CACHE, 'giga'), clean=True) + + super().__init__(path=path, *args, **kwargs) + + +class TencentEmbedding(Embedding): + + r""" + `Tencent`_ word embeddings. + The embeddings are trained on large-scale text collected from news, webpages, and novels with Directional Skip-Gram. + 100-dimension and 200-dimension embeddings for over 12 million Chinese words are provided. + + Args: + dim (int): + Which dimension of the embeddings to use. Currently 100 and 200 are available. Default: 100. + large (bool): + If ``True``, uses large version with larger vocab size (12,287,933); 2,000,000 otherwise. Default: ``False``. + reload (bool): + If ``True``, forces a fresh download. Default: ``False``. + + Examples: + >>> from supar.utils.embed import Embedding + >>> Embedding.load('tencent-100') + TencentEmbedding(n_tokens=2000000, dim=100, skip_first=True, cache=True) + >>> Embedding.load('tencent-100-large') + TencentEmbedding(n_tokens=12287933, dim=100, skip_first=True, cache=True) + + .. _Tencent: + https://ai.tencent.com/ailab/nlp/zh/download.html + """ + + def __init__(self, dim: int = 100, large: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding: + url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}.tar.gz' # noqa + name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}' + path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt') + if not os.path.exists(path) or reload: + download(url, os.path.join(self.CACHE, 'tencent'), clean=True) + + super().__init__(path=path, skip_first=True, *args, **kwargs) + + +PRETRAINED = { + 'glove-6b-50': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 50}, + 'glove-6b-100': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100}, + 'glove-6b-200': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 200}, + 'glove-6b-300': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 300}, + 'glove-42b-300': {'_target_': GloVeEmbedding, 'src': '42B', 'dim': 300}, + 'glove-840b-300': {'_target_': GloVeEmbedding, 'src': '84B', 'dim': 300}, + 'glove-twitter-27b-25': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 25}, + 'glove-twitter-27b-50': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 50}, + 'glove-twitter-27b-100': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 100}, + 'glove-twitter-27b-200': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 200}, + 'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'}, + 'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'}, + 'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'}, + 'fasttext-de': {'_target_': FasttextEmbedding, 'lang': 'de'}, + 'fasttext-en': {'_target_': FasttextEmbedding, 'lang': 'en'}, + 'fasttext-es': {'_target_': FasttextEmbedding, 'lang': 'es'}, + 'fasttext-fr': {'_target_': FasttextEmbedding, 'lang': 'fr'}, + 'fasttext-it': {'_target_': FasttextEmbedding, 'lang': 'it'}, + 'fasttext-nl': {'_target_': FasttextEmbedding, 'lang': 'nl'}, + 'fasttext-no': {'_target_': FasttextEmbedding, 'lang': 'no'}, + 'fasttext-ro': {'_target_': FasttextEmbedding, 'lang': 'ro'}, + 'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'}, + 'giga-100': {'_target_': GigaEmbedding}, + 'tencent-100': {'_target_': TencentEmbedding, 'dim': 100}, + 'tencent-100-large': {'_target_': TencentEmbedding, 'dim': 100, 'large': True}, + 'tencent-200': {'_target_': TencentEmbedding, 'dim': 200}, + 'tencent-200-large': {'_target_': TencentEmbedding, 'dim': 200, 'large': True}, +} diff --git a/tania_scripts/supar/utils/field.py b/tania_scripts/supar/utils/field.py new file mode 100644 index 0000000000000000000000000000000000000000..dd14bc78f7c45302b61c5b17ff50c90cbde22abd --- /dev/null +++ b/tania_scripts/supar/utils/field.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import Counter +from typing import Callable, Iterable, List, Optional, Union + +import torch +from supar.utils.data import Dataset +from supar.utils.embed import Embedding +from supar.utils.fn import pad +from supar.utils.logging import progress_bar +from supar.utils.parallel import wait +from supar.utils.vocab import Vocab + + +class RawField(object): + r""" + Defines a general datatype. + + A :class:`RawField` object does not assume any property of the datatype and + it holds parameters relating to how a datatype should be processed. + + Args: + name (str): + The name of the field. + fn (function): + The function used for preprocessing the examples. Default: ``None``. + """ + + def __init__(self, name: str, fn: Optional[Callable] = None) -> RawField: + self.name = name + self.fn = fn + + def __repr__(self): + return f"({self.name}): {self.__class__.__name__}()" + + def preprocess(self, sequence: Iterable) -> Iterable: + return self.fn(sequence) if self.fn is not None else sequence + + def transform(self, sequences: Iterable[List]) -> Iterable[List]: + return (self.preprocess(seq) for seq in sequences) + + def compose(self, sequences: Iterable[List]) -> Iterable[List]: + return sequences + + +class Field(RawField): + r""" + Defines a datatype together with instructions for converting to :class:`~torch.Tensor`. + :class:`Field` models common text processing datatypes that can be represented by tensors. + It holds a :class:`~supar.utils.vocab.Vocab` object that defines the set of possible values + for elements of the field and their corresponding numerical representations. + The :class:`Field` object also holds other parameters relating to how a datatype + should be numericalized, such as a tokenization method. + + Args: + name (str): + The name of the field. + pad_token (str): + The string token used as padding. Default: ``None``. + unk_token (str): + The string token used to represent OOV words. Default: ``None``. + bos_token (str): + A token that will be prepended to every example using this field, or ``None`` for no `bos_token`. + Default: ``None``. + eos_token (str): + A token that will be appended to every example using this field, or ``None`` for no `eos_token`. + lower (bool): + Whether to lowercase the text in this field. Default: ``False``. + use_vocab (bool): + Whether to use a :class:`~supar.utils.vocab.Vocab` object. + If ``False``, the data in this field should already be numerical. + Default: ``True``. + tokenize (function): + The function used to tokenize strings using this field into sequential examples. Default: ``None``. + fn (function): + The function used for preprocessing the examples. Default: ``None``. + """ + + def __init__( + self, + name: str, + pad: Optional[str] = None, + unk: Optional[str] = None, + bos: Optional[str] = None, + eos: Optional[str] = None, + lower: bool = False, + use_vocab: bool = True, + tokenize: Optional[Callable] = None, + fn: Optional[Callable] = None, + delay: Optional[int] = 0 + ) -> Field: + self.name = name + self.pad = pad + self.unk = unk + self.bos = bos + self.eos = eos + self.lower = lower + self.use_vocab = use_vocab + self.tokenize = tokenize + self.fn = fn + self.delay = delay + + self.specials = [token for token in [pad, unk, bos, eos] if token is not None] + + def __repr__(self): + s, params = f"({self.name}): {self.__class__.__name__}(", [] + if hasattr(self, 'vocab'): + params.append(f"vocab_size={len(self.vocab)}") + if self.pad is not None: + params.append(f"pad={self.pad}") + if self.unk is not None: + params.append(f"unk={self.unk}") + if self.bos is not None: + params.append(f"bos={self.bos}") + if self.eos is not None: + params.append(f"eos={self.eos}") + if self.lower: + params.append(f"lower={self.lower}") + if not self.use_vocab: + params.append(f"use_vocab={self.use_vocab}") + if self.delay > 0: + params.append(f'delay={self.delay}') + return s + ', '.join(params) + ')' + + @property + def pad_index(self): + if self.pad is None: + return 0 + if hasattr(self, 'vocab'): + return self.vocab[self.pad] + return self.specials.index(self.pad) + + @property + def unk_index(self): + if self.unk is None: + return 0 + if hasattr(self, 'vocab'): + return self.vocab[self.unk] + return self.specials.index(self.unk) + + @property + def bos_index(self): + if hasattr(self, 'vocab'): + return self.vocab[self.bos] + return self.specials.index(self.bos) + + @property + def eos_index(self): + if hasattr(self, 'vocab'): + return self.vocab[self.eos] + return self.specials.index(self.eos) + + @property + def device(self): + return 'cuda' if torch.cuda.is_available() else 'cpu' + + def preprocess(self, data: Union[str, Iterable]) -> Iterable: + r""" + Loads a single example and tokenize it if necessary. + The sequence will be first passed to ``fn`` if available. + If ``tokenize`` is not None, the input will be tokenized. + Then the input will be lowercased optionally. + + Args: + data (Union[str, Iterable]): + The data to be preprocessed. + + Returns: + A list of preprocessed sequence. + """ + + if self.fn is not None: + data = self.fn(data) + if self.tokenize is not None: + data = self.tokenize(data) + if self.lower: + data = [str.lower(token) for token in data] + return data + + def build( + self, + dataset: Dataset, + min_freq: int = 1, + embed: Optional[Embedding] = None, + norm: Callable = None + ) -> Field: + r""" + Constructs a :class:`~supar.utils.vocab.Vocab` object for this field from the dataset. + If the vocabulary has already existed, this function will have no effect. + + Args: + dataset (Dataset): + A :class:`~supar.utils.data.Dataset` object. + One of the attributes should be named after the name of this field. + min_freq (int): + The minimum frequency needed to include a token in the vocabulary. Default: 1. + embed (Embedding): + An Embedding object, words in which will be extended to the vocabulary. Default: ``None``. + norm (Callable): + Callable function used for normalizing embedding weights. Default: ``None``. + """ + + if hasattr(self, 'vocab'): + return + + @wait + def build_vocab(dataset): + return Vocab(counter=Counter(token + for seq in progress_bar(getattr(dataset, self.name)) + for token in self.preprocess(seq)), + min_freq=min_freq, + specials=self.specials, + unk_index=self.unk_index) + self.vocab = build_vocab(dataset) + + if not embed: + self.embed = None + else: + tokens = self.preprocess(embed.tokens) + # replace the `unk` token in the pretrained with a self-defined one if existed + if embed.unk: + tokens[embed.unk_index] = self.unk + + self.vocab.update(tokens) + self.embed = torch.zeros(len(self.vocab), embed.dim) + self.embed[self.vocab[tokens]] = embed.vectors + if norm is not None: + self.embed = norm(self.embed) + return self + + def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]: + r""" + Turns a list of sequences that use this field into tensors. + + Each sequence is first preprocessed and then numericalized if needed. + + Args: + sequences (Iterable[List[str]]): + A list of sequences. + + Returns: + A list of tensors transformed from the input sequences. + """ + + for seq in sequences: + seq = self.preprocess(seq) + if self.use_vocab: + try: + seq = [self.vocab[token] for token in seq] + except: + raise AssertionError + if self.bos: + seq = [self.bos_index] + seq + if self.delay > 0: + seq = seq + [self.pad_index for _ in range(self.delay)] + if self.eos: + seq = seq + [self.eos_index] + yield torch.tensor(seq, dtype=torch.long) + + def compose(self, batch: Iterable[torch.Tensor]) -> torch.Tensor: + r""" + Composes a batch of sequences into a padded tensor. + + Args: + batch (Iterable[~torch.Tensor]): + A list of tensors. + + Returns: + A padded tensor converted to proper device. + """ + + return pad(batch, self.pad_index).to(self.device, non_blocking=True) + + +class SubwordField(Field): + r""" + A field that conducts tokenization and numericalization over each token rather the sequence. + + This is customized for models requiring character/subword-level inputs, e.g., CharLSTM and BERT. + + Args: + fix_len (int): + A fixed length that all subword pieces will be padded to. + This is used for truncating the subword pieces exceeding the length. + To save the memory, the final length will be the smaller value + between the max length of subword pieces in a batch and `fix_len`. + + Examples: + >>> from supar.utils.tokenizer import TransformerTokenizer + >>> tokenizer = TransformerTokenizer('bert-base-cased') + >>> field = SubwordField('bert', + pad=tokenizer.pad, + unk=tokenizer.unk, + bos=tokenizer.bos, + eos=tokenizer.eos, + fix_len=20, + tokenize=tokenizer) + >>> field.vocab = tokenizer.vocab # no need to re-build the vocab + >>> next(field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']])) + tensor([[ 101, 0, 0], + [ 1188, 0, 0], + [ 1768, 0, 0], + [10383, 0, 0], + [22559, 118, 1634], + [22559, 2734, 0], + [ 102, 0, 0]]) + """ + + def __init__(self, *args, **kwargs): + self.fix_len = kwargs.pop('fix_len') if 'fix_len' in kwargs else 0 + super().__init__(*args, **kwargs) + + def build( + self, + dataset: Dataset, + min_freq: int = 1, + embed: Optional[Embedding] = None, + norm: Callable = None + ) -> SubwordField: + if hasattr(self, 'vocab'): + return + + @wait + def build_vocab(dataset): + return Vocab(counter=Counter(piece + for seq in progress_bar(getattr(dataset, self.name)) + for token in seq + for piece in self.preprocess(token)), + min_freq=min_freq, + specials=self.specials, + unk_index=self.unk_index) + self.vocab = build_vocab(dataset) + + if not embed: + self.embed = None + else: + tokens = self.preprocess(embed.tokens) + # if the `unk` token has existed in the pretrained, + # then replace it with a self-defined one + if embed.unk: + tokens[embed.unk_index] = self.unk + + self.vocab.update(tokens) + self.embed = torch.zeros(len(self.vocab), embed.dim) + self.embed[self.vocab[tokens]] = embed.vectors + if norm is not None: + self.embed = norm(self.embed) + return self + + def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]: + for seq in sequences: + seq = [self.preprocess(token) for token in seq] + if self.use_vocab: + seq = [[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index] + for token in seq] + if self.bos: + seq = [[self.bos_index]] + seq + if self.delay > 0: + seq = seq + [[self.pad_index] for _ in range(self.delay)] + if self.eos: + seq = seq + [[self.eos_index]] + if self.fix_len > 0: + seq = [ids[:self.fix_len] for ids in seq] + yield pad([torch.tensor(ids, dtype=torch.long) for ids in seq], self.pad_index) + + +class ChartField(Field): + r""" + Field dealing with chart inputs. + + Examples: + >>> chart = [[ None, 'NP', None, None, 'S*', 'S'], + [ None, None, 'VP*', None, 'VP', None], + [ None, None, None, 'VP*', 'S::VP', None], + [ None, None, None, None, 'NP', None], + [ None, None, None, None, None, 'S*'], + [ None, None, None, None, None, None]] + >>> next(field.transform([chart])) + tensor([[ -1, 37, -1, -1, 107, 79], + [ -1, -1, 120, -1, 112, -1], + [ -1, -1, -1, 120, 86, -1], + [ -1, -1, -1, -1, 37, -1], + [ -1, -1, -1, -1, -1, 107], + [ -1, -1, -1, -1, -1, -1]]) + """ + + def build( + self, + dataset: Dataset, + min_freq: int = 1 + ) -> ChartField: + @wait + def build_vocab(dataset): + return Vocab(counter=Counter(i + for chart in progress_bar(getattr(dataset, self.name)) + for row in self.preprocess(chart) + for i in row if i is not None), + min_freq=min_freq, + specials=self.specials, + unk_index=self.unk_index) + self.vocab = build_vocab(dataset) + return self + + def transform(self, charts: Iterable[List[List]]) -> Iterable[torch.Tensor]: + for chart in charts: + chart = self.preprocess(chart) + if self.use_vocab: + chart = [[self.vocab[i] if i is not None else -1 for i in row] for row in chart] + if self.bos: + chart = [[self.bos_index]*len(chart[0])] + chart + if self.eos: + chart = chart + [[self.eos_index]*len(chart[0])] + yield torch.tensor(chart, dtype=torch.long) diff --git a/tania_scripts/supar/utils/fn.py b/tania_scripts/supar/utils/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..95d826a1dea476907b520a9eff8ea6ceb6a50f95 --- /dev/null +++ b/tania_scripts/supar/utils/fn.py @@ -0,0 +1,383 @@ +# -*- coding: utf-8 -*- + +import gzip +import mmap +import os +import pickle +import shutil +import struct +import sys +import tarfile +import unicodedata +import urllib +import zipfile +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from supar.utils.common import CACHE +from supar.utils.parallel import wait + + +def ispunct(token: str) -> bool: + return all(unicodedata.category(char).startswith('P') for char in token) + + +def isfullwidth(token: str) -> bool: + return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] for char in token) + + +def islatin(token: str) -> bool: + return all('LATIN' in unicodedata.name(char) for char in token) + + +def isdigit(token: str) -> bool: + return all('DIGIT' in unicodedata.name(char) for char in token) + + +def tohalfwidth(token: str) -> str: + return unicodedata.normalize('NFKC', token) + + +def kmeans(x: List[int], k: int, max_it: int = 32) -> Tuple[List[float], List[List[int]]]: + r""" + KMeans algorithm for clustering the sentences by length. + + Args: + x (List[int]): + The list of sentence lengths. + k (int): + The number of clusters, which is an approximate value. + The final number of clusters can be less or equal to `k`. + max_it (int): + Maximum number of iterations. + If centroids does not converge after several iterations, the algorithm will be early stopped. + + Returns: + List[float], List[List[int]]: + The first list contains average lengths of sentences in each cluster. + The second is the list of clusters holding indices of data points. + + Examples: + >>> x = torch.randint(10, 20, (10,)).tolist() + >>> x + [15, 10, 17, 11, 18, 13, 17, 19, 18, 14] + >>> centroids, clusters = kmeans(x, 3) + >>> centroids + [10.5, 14.0, 17.799999237060547] + >>> clusters + [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]] + """ + + x = torch.tensor(x, dtype=torch.float) + # collect unique datapoints + datapoints, indices, freqs = x.unique(return_inverse=True, return_counts=True) + # the number of clusters must not be greater than the number of datapoints + k = min(len(datapoints), k) + # initialize k centroids randomly + centroids = datapoints[torch.randperm(len(datapoints))[:k]] + # assign each datapoint to the cluster with the closest centroid + dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1) + + for _ in range(max_it): + # if an empty cluster is encountered, + # choose the farthest datapoint from the biggest cluster and move that the empty one + mask = torch.arange(k).unsqueeze(-1).eq(y) + none = torch.where(~mask.any(-1))[0].tolist() + for i in none: + # the biggest cluster + biggest = torch.where(mask[mask.sum(-1).argmax()])[0] + # the datapoint farthest from the centroid of the biggest cluster + farthest = dists[biggest].argmax() + # update the assigned cluster of the farthest datapoint + y[biggest[farthest]] = i + # re-calculate the mask + mask = torch.arange(k).unsqueeze(-1).eq(y) + # update the centroids + centroids, old = (datapoints * freqs * mask).sum(-1) / (freqs * mask).sum(-1), centroids + # re-assign all datapoints to clusters + dists, y = torch.abs_(datapoints.unsqueeze(-1) - centroids).min(-1) + # stop iteration early if the centroids converge + if centroids.equal(old): + break + # assign all datapoints to the new-generated clusters + # the empty ones are discarded + assigned = y.unique().tolist() + # get the centroids of the assigned clusters + centroids = centroids[assigned].tolist() + # map all values of datapoints to buckets + clusters = [torch.where(indices.unsqueeze(-1).eq(torch.where(y.eq(i))[0]).any(-1))[0].tolist() for i in assigned] + + return centroids, clusters + + +def stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0), horizontal: bool = True) -> torch.Tensor: + r""" + Returns a parallelogram stripe of the tensor. + + Args: + x (~torch.Tensor): the input tensor with 2 or more dims. + n (int): the length of the stripe. + w (int): the width of the stripe. + offset (tuple): the offset of the first two dims. + horizontal (bool): `True` if returns a horizontal stripe; `False` otherwise. + + Returns: + A parallelogram stripe of the tensor. + + Examples: + >>> x = torch.arange(25).view(5, 5) + >>> x + tensor([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + >>> stripe(x, 2, 3) + tensor([[0, 1, 2], + [6, 7, 8]]) + >>> stripe(x, 2, 3, (1, 1)) + tensor([[ 6, 7, 8], + [12, 13, 14]]) + >>> stripe(x, 2, 3, (1, 1), 0) + tensor([[ 6, 11, 16], + [12, 17, 22]]) + """ + + x = x.contiguous() + seq_len, stride = x.size(1), list(x.stride()) + numel = stride[1] + return x.as_strided(size=(n, w, *x.shape[2:]), + stride=[(seq_len + 1) * numel, (1 if horizontal else seq_len) * numel] + stride[2:], + storage_offset=(offset[0]*seq_len+offset[1])*numel) + + +def diagonal_stripe(x: torch.Tensor, offset: int = 1) -> torch.Tensor: + r""" + Returns a diagonal parallelogram stripe of the tensor. + + Args: + x (~torch.Tensor): the input tensor with 3 or more dims. + offset (int): which diagonal to consider. Default: 1. + + Returns: + A diagonal parallelogram stripe of the tensor. + + Examples: + >>> x = torch.arange(125).view(5, 5, 5) + >>> diagonal_stripe(x) + tensor([[ 5], + [36], + [67], + [98]]) + >>> diagonal_stripe(x, 2) + tensor([[10, 11], + [41, 42], + [72, 73]]) + >>> diagonal_stripe(x, -2) + tensor([[ 50, 51], + [ 81, 82], + [112, 113]]) + """ + + x = x.contiguous() + seq_len, stride = x.size(1), list(x.stride()) + n, w, numel = seq_len - abs(offset), abs(offset), stride[2] + return x.as_strided(size=(n, w, *x.shape[3:]), + stride=[((seq_len + 1) * x.size(2) + 1) * numel] + stride[2:], + storage_offset=offset*stride[1] if offset > 0 else abs(offset)*stride[0]) + + +def expanded_stripe(x: torch.Tensor, n: int, w: int, offset: Tuple = (0, 0)) -> torch.Tensor: + r""" + Returns an expanded parallelogram stripe of the tensor. + + Args: + x (~torch.Tensor): the input tensor with 2 or more dims. + n (int): the length of the stripe. + w (int): the width of the stripe. + offset (tuple): the offset of the first two dims. + + Returns: + An expanded parallelogram stripe of the tensor. + + Examples: + >>> x = torch.arange(25).view(5, 5) + >>> x + tensor([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + >>> expanded_stripe(x, 2, 3) + tensor([[[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14]], + + [[ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]]]) + >>> expanded_stripe(x, 2, 3, (1, 1)) + tensor([[[ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + + [[10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]]) + + """ + x = x.contiguous() + stride = list(x.stride()) + return x.as_strided(size=(n, w, *list(x.shape[1:])), + stride=stride[:1] + [stride[0]] + stride[1:], + storage_offset=(offset[1])*stride[0]) + + +def pad( + tensors: List[torch.Tensor], + padding_value: int = 0, + total_length: int = None, + padding_side: str = 'right' +) -> torch.Tensor: + size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) + for i in range(len(tensors[0].size()))] + if total_length is not None: + assert total_length >= size[1] + size[1] = total_length + out_tensor = tensors[0].data.new(*size).fill_(padding_value) + for i, tensor in enumerate(tensors): + out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor + return out_tensor + + +@wait +def download(url: str, path: Optional[str] = None, reload: bool = False, clean: bool = False) -> str: + filename = os.path.basename(urllib.parse.urlparse(url).path) + if path is None: + path = CACHE + os.makedirs(path, exist_ok=True) + path = os.path.join(path, filename) + if reload and os.path.exists(path): + os.remove(path) + if not os.path.exists(path): + sys.stderr.write(f"Downloading {url} to {path}\n") + try: + torch.hub.download_url_to_file(url, path, progress=True) + except (ValueError, urllib.error.URLError): + raise RuntimeError(f"File {url} unavailable. Please try other sources.") + return extract(path, reload, clean) + + +def extract(path: str, reload: bool = False, clean: bool = False) -> str: + extracted = path + if zipfile.is_zipfile(path): + with zipfile.ZipFile(path) as f: + extracted = os.path.join(os.path.dirname(path), f.infolist()[0].filename) + if reload or not os.path.exists(extracted): + f.extractall(os.path.dirname(path)) + elif tarfile.is_tarfile(path): + with tarfile.open(path) as f: + extracted = os.path.join(os.path.dirname(path), f.getnames()[0]) + if reload or not os.path.exists(extracted): + f.extractall(os.path.dirname(path)) + elif path.endswith('.gz'): + extracted = path[:-3] + with gzip.open(path) as fgz: + with open(extracted, 'wb') as f: + shutil.copyfileobj(fgz, f) + if clean: + os.remove(path) + return extracted + + +def binarize( + data: Union[List[str], Dict[str, Iterable]], + fbin: str = None, + merge: bool = False +) -> Tuple[str, torch.Tensor]: + start, meta = 0, defaultdict(list) + # the binarized file is organized as: + # `data`: pickled objects + # `meta`: a dict containing the pointers of each kind of data + # `index`: fixed size integers representing the storage positions of the meta data + with open(fbin, 'wb') as f: + # in this case, data should be a list of binarized files + if merge: + for file in data: + if not os.path.exists(file): + raise RuntimeError("Some files are missing. Please check the paths") + mi = debinarize(file, meta=True) + for key, val in mi.items(): + val[:, 0] += start + meta[key].append(val) + with open(file, 'rb') as fi: + length = int(sum(val[:, 1].sum() for val in mi.values())) + f.write(fi.read(length)) + start = start + length + meta = {key: torch.cat(val) for key, val in meta.items()} + else: + for key, val in data.items(): + for i in val: + bytes = pickle.dumps(i) + f.write(bytes) + meta[key].append((start, len(bytes))) + start = start + len(bytes) + meta = {key: torch.tensor(val) for key, val in meta.items()} + pickled = pickle.dumps(meta) + # append the meta data to the end of the bin file + f.write(pickled) + # record the positions of the meta data + f.write(struct.pack('LL', start, len(pickled))) + return fbin, meta + + +def debinarize( + fbin: str, + pos_or_key: Optional[Union[Tuple[int, int], str]] = (0, 0), + meta: bool = False +) -> Union[Any, Iterable[Any]]: + with open(fbin, 'rb') as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + if meta or isinstance(pos_or_key, str): + length = len(struct.pack('LL', 0, 0)) + mm.seek(-length, os.SEEK_END) + offset, length = struct.unpack('LL', mm.read(length)) + mm.seek(offset) + if meta: + return pickle.loads(mm.read(length)) + # fetch by key + objs, meta = [], pickle.loads(mm.read(length))[pos_or_key] + for offset, length in meta.tolist(): + mm.seek(offset) + objs.append(pickle.loads(mm.read(length))) + return objs + # fetch by positions + offset, length = pos_or_key + mm.seek(offset) + return pickle.loads(mm.read(length)) + + +def resolve_config(args: Union[Dict, DictConfig]) -> DictConfig: + OmegaConf.register_new_resolver("eval", eval) + return DictConfig(OmegaConf.to_container(args, resolve=True)) + + +def collect_args(args: Union[Dict, DictConfig]) -> DictConfig: + for key in ('self', 'cls', '__class__'): + args.pop(key, None) + args.update(args.pop('kwargs', dict())) + return DictConfig(args) + + +def get_rng_state() -> Dict[str, torch.Tensor]: + state = {'rng_state': torch.get_rng_state()} + if torch.cuda.is_available(): + state['cuda_rng_state'] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state: Dict) -> None: + torch.set_rng_state(state['rng_state']) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state['cuda_rng_state']) diff --git a/tania_scripts/supar/utils/logging.py b/tania_scripts/supar/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..73e4a77a0ca5f660083ae74ff0c8377efeb5a280 --- /dev/null +++ b/tania_scripts/supar/utils/logging.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +import logging +import os +from logging import FileHandler, Formatter, Handler, Logger, StreamHandler +from typing import Iterable, Optional + +from supar.utils.parallel import is_master +from tqdm import tqdm + + +def get_logger(name: Optional[str] = None) -> Logger: + logger = logging.getLogger(name) + # init the root logger + if name is None: + logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[TqdmHandler()]) + return logger + + +class TqdmHandler(StreamHandler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.handleError(record) + + +def init_logger( + logger: Logger, + path: Optional[str] = None, + mode: str = 'w', + handlers: Optional[Iterable[Handler]] = None, + verbose: bool = True +) -> Logger: + if not handlers: + if path: + os.makedirs(os.path.dirname(path) or './', exist_ok=True) + logger.addHandler(FileHandler(path, mode)) + for handler in logger.handlers: + handler.setFormatter(ColoredFormatter(colored=not isinstance(handler, FileHandler))) + logger.setLevel(logging.INFO if is_master() and verbose else logging.WARNING) + return logger + + +def progress_bar( + iterator: Iterable, + ncols: Optional[int] = None, + bar_format: str = '{l_bar}{bar:20}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}', + leave: bool = False, + **kwargs +) -> tqdm: + return tqdm(iterator, + ncols=ncols, + bar_format=bar_format, + ascii=True, + disable=(not (logger.level == logging.INFO and is_master())), + leave=leave, + **kwargs) + + +class ColoredFormatter(Formatter): + + BLACK = '\033[30m' + RED = '\033[31m' + GREEN = '\033[32m' + GREY = '\033[37m' + RESET = '\033[0m' + + COLORS = { + logging.ERROR: RED, + logging.WARNING: RED, + logging.INFO: GREEN, + logging.DEBUG: BLACK, + logging.NOTSET: BLACK + } + + def __init__(self, colored=True, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.colored = colored + + def format(self, record): + fmt = '[%(asctime)s %(levelname)s] %(message)s' + if self.colored: + fmt = f'{self.COLORS[record.levelno]}[%(asctime)s %(levelname)s]{self.RESET} %(message)s' + datefmt = '%Y-%m-%d %H:%M:%S' + return Formatter(fmt=fmt, datefmt=datefmt).format(record) + + +logger = get_logger() diff --git a/tania_scripts/supar/utils/metric.py b/tania_scripts/supar/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..f64940c1a909006d264c86976242ea6ff931affe --- /dev/null +++ b/tania_scripts/supar/utils/metric.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import Counter +from typing import Dict, List, Optional, Tuple + +import torch + + +class Metric(object): + + def __init__(self, reverse: Optional[bool] = None, eps: float = 1e-12) -> Metric: + super().__init__() + + self.n = 0.0 + self.count = 0.0 + self.total_loss = 0.0 + self.reverse = reverse + self.eps = eps + + def __repr__(self): + return f"loss: {self.loss:.4f} - " + ' '.join([f"{key}: {val:6.2%}" for key, val in self.values.items()]) + + def __lt__(self, other: Metric) -> bool: + if not hasattr(self, 'score'): + return True + if not hasattr(other, 'score'): + return False + return (self.score < other.score) if not self.reverse else (self.score > other.score) + + def __le__(self, other: Metric) -> bool: + if not hasattr(self, 'score'): + return True + if not hasattr(other, 'score'): + return False + return (self.score <= other.score) if not self.reverse else (self.score >= other.score) + + def __gt__(self, other: Metric) -> bool: + if not hasattr(self, 'score'): + return False + if not hasattr(other, 'score'): + return True + return (self.score > other.score) if not self.reverse else (self.score < other.score) + + def __ge__(self, other: Metric) -> bool: + if not hasattr(self, 'score'): + return False + if not hasattr(other, 'score'): + return True + return (self.score >= other.score) if not self.reverse else (self.score <= other.score) + + def __add__(self, other: Metric) -> Metric: + return other + + @property + def score(self): + raise AttributeError + + @property + def loss(self): + return self.total_loss / (self.count + self.eps) + + @property + def values(self): + raise AttributeError + + +class AttachmentMetric(Metric): + + def __init__( + self, + loss: Optional[float] = None, + preds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + golds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + mask: Optional[torch.BoolTensor] = None, + reverse: bool = False, + eps: float = 1e-12 + ) -> AttachmentMetric: + super().__init__(reverse=reverse, eps=eps) + + self.n_ucm = 0.0 + self.n_lcm = 0.0 + self.total = 0.0 + self.correct_arcs = 0.0 + self.correct_rels = 0.0 + + if loss is not None: + self(loss, preds, golds, mask) + + def __call__( + self, + loss: float, + preds: Tuple[torch.Tensor, torch.Tensor], + golds: Tuple[torch.Tensor, torch.Tensor], + mask: torch.BoolTensor + ) -> AttachmentMetric: + lens = mask.sum(1) + arc_preds, rel_preds, arc_golds, rel_golds = *preds, *golds + arc_mask = arc_preds.eq(arc_golds) & mask + rel_mask = rel_preds.eq(rel_golds) & arc_mask + arc_mask_seq, rel_mask_seq = arc_mask[mask], rel_mask[mask] + + self.n += len(mask) + self.count += 1 + self.total_loss += float(loss) + self.n_ucm += arc_mask.sum(1).eq(lens).sum().item() + self.n_lcm += rel_mask.sum(1).eq(lens).sum().item() + + self.total += len(arc_mask_seq) + self.correct_arcs += arc_mask_seq.sum().item() + self.correct_rels += rel_mask_seq.sum().item() + return self + + def __add__(self, other: AttachmentMetric) -> AttachmentMetric: + metric = AttachmentMetric(eps=self.eps) + metric.n = self.n + other.n + metric.count = self.count + other.count + metric.total_loss = self.total_loss + other.total_loss + metric.n_ucm = self.n_ucm + other.n_ucm + metric.n_lcm = self.n_lcm + other.n_lcm + metric.total = self.total + other.total + metric.correct_arcs = self.correct_arcs + other.correct_arcs + metric.correct_rels = self.correct_rels + other.correct_rels + metric.reverse = self.reverse or other.reverse + return metric + + @property + def score(self): + return self.las + + @property + def ucm(self): + return self.n_ucm / (self.n + self.eps) + + @property + def lcm(self): + return self.n_lcm / (self.n + self.eps) + + @property + def uas(self): + return self.correct_arcs / (self.total + self.eps) + + @property + def las(self): + return self.correct_rels / (self.total + self.eps) + + @property + def values(self) -> Dict: + return {'UCM': self.ucm, + 'LCM': self.lcm, + 'UAS': self.uas, + 'LAS': self.las} + + +class SpanMetric(Metric): + + def __init__( + self, + loss: Optional[float] = None, + preds: Optional[List[List[Tuple]]] = None, + golds: Optional[List[List[Tuple]]] = None, + reverse: bool = False, + eps: float = 1e-12 + ) -> SpanMetric: + super().__init__(reverse=reverse, eps=eps) + + self.n_ucm = 0.0 + self.n_lcm = 0.0 + self.utp = 0.0 + self.ltp = 0.0 + self.pred = 0.0 + self.gold = 0.0 + + if loss is not None: + self(loss, preds, golds) + + def __call__( + self, + loss: float, + preds: List[List[Tuple]], + golds: List[List[Tuple]] + ) -> SpanMetric: + self.n += len(preds) + self.count += 1 + self.total_loss += float(loss) + for pred, gold in zip(preds, golds): + upred, ugold = Counter([tuple(span[:-1]) for span in pred]), Counter([tuple(span[:-1]) for span in gold]) + lpred, lgold = Counter([tuple(span) for span in pred]), Counter([tuple(span) for span in gold]) + utp, ltp = list((upred & ugold).elements()), list((lpred & lgold).elements()) + self.n_ucm += len(utp) == len(pred) == len(gold) + self.n_lcm += len(ltp) == len(pred) == len(gold) + self.utp += len(utp) + self.ltp += len(ltp) + self.pred += len(pred) + self.gold += len(gold) + return self + + def __add__(self, other: SpanMetric) -> SpanMetric: + metric = SpanMetric(eps=self.eps) + metric.n = self.n + other.n + metric.count = self.count + other.count + metric.total_loss = self.total_loss + other.total_loss + metric.n_ucm = self.n_ucm + other.n_ucm + metric.n_lcm = self.n_lcm + other.n_lcm + metric.utp = self.utp + other.utp + metric.ltp = self.ltp + other.ltp + metric.pred = self.pred + other.pred + metric.gold = self.gold + other.gold + metric.reverse = self.reverse or other.reverse + return metric + + @property + def score(self): + return self.lf + + @property + def ucm(self): + return self.n_ucm / (self.n + self.eps) + + @property + def lcm(self): + return self.n_lcm / (self.n + self.eps) + + @property + def up(self): + return self.utp / (self.pred + self.eps) + + @property + def ur(self): + return self.utp / (self.gold + self.eps) + + @property + def uf(self): + return 2 * self.utp / (self.pred + self.gold + self.eps) + + @property + def lp(self): + return self.ltp / (self.pred + self.eps) + + @property + def lr(self): + return self.ltp / (self.gold + self.eps) + + @property + def lf(self): + return 2 * self.ltp / (self.pred + self.gold + self.eps) + + @property + def values(self) -> Dict: + return {'UCM': self.ucm, + 'LCM': self.lcm, + 'UP': self.up, + 'UR': self.ur, + 'UF': self.uf, + 'LP': self.lp, + 'LR': self.lr, + 'LF': self.lf} + + +class ChartMetric(Metric): + + def __init__( + self, + loss: Optional[float] = None, + preds: Optional[torch.Tensor] = None, + golds: Optional[torch.Tensor] = None, + reverse: bool = False, + eps: float = 1e-12 + ) -> ChartMetric: + super().__init__(reverse=reverse, eps=eps) + + self.tp = 0.0 + self.utp = 0.0 + self.pred = 0.0 + self.gold = 0.0 + + if loss is not None: + self(loss, preds, golds) + + def __call__( + self, + loss: float, + preds: torch.Tensor, + golds: torch.Tensor + ) -> ChartMetric: + self.n += len(preds) + self.count += 1 + self.total_loss += float(loss) + pred_mask = preds.ge(0) + gold_mask = golds.ge(0) + span_mask = pred_mask & gold_mask + self.pred += pred_mask.sum().item() + self.gold += gold_mask.sum().item() + self.tp += (preds.eq(golds) & span_mask).sum().item() + self.utp += span_mask.sum().item() + return self + + def __add__(self, other: ChartMetric) -> ChartMetric: + metric = ChartMetric(eps=self.eps) + metric.n = self.n + other.n + metric.count = self.count + other.count + metric.total_loss = self.total_loss + other.total_loss + metric.tp = self.tp + other.tp + metric.utp = self.utp + other.utp + metric.pred = self.pred + other.pred + metric.gold = self.gold + other.gold + metric.reverse = self.reverse or other.reverse + return metric + + @property + def score(self): + return self.f + + @property + def up(self): + return self.utp / (self.pred + self.eps) + + @property + def ur(self): + return self.utp / (self.gold + self.eps) + + @property + def uf(self): + return 2 * self.utp / (self.pred + self.gold + self.eps) + + @property + def p(self): + return self.tp / (self.pred + self.eps) + + @property + def r(self): + return self.tp / (self.gold + self.eps) + + @property + def f(self): + return 2 * self.tp / (self.pred + self.gold + self.eps) + + @property + def values(self) -> Dict: + return {'UP': self.up, + 'UR': self.ur, + 'UF': self.uf, + 'P': self.p, + 'R': self.r, + 'F': self.f} diff --git a/tania_scripts/supar/utils/optim.py b/tania_scripts/supar/utils/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..e67730b4eacbf68008f6b955cb0fa7bea2f33bc1 --- /dev/null +++ b/tania_scripts/supar/utils/optim.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class InverseSquareRootLR(_LRScheduler): + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int, + last_epoch: int = -1 + ) -> InverseSquareRootLR: + self.warmup_steps = warmup_steps + self.factor = warmup_steps ** 0.5 + super(InverseSquareRootLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + epoch = max(self.last_epoch, 1) + scale = min(epoch ** -0.5, epoch * self.warmup_steps ** -1.5) * self.factor + return [scale * lr for lr in self.base_lrs] + + +class PolynomialLR(_LRScheduler): + r""" + Set the learning rate for each parameter group using a polynomial defined as: `lr = base_lr * (1 - t / T) ^ (power)`, + where `t` is the current epoch and `T` is the maximum number of epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int = 0, + steps: int = 100000, + power: float = 1., + last_epoch: int = -1 + ) -> PolynomialLR: + self.warmup_steps = warmup_steps + self.steps = steps + self.power = power + super(PolynomialLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + epoch = max(self.last_epoch, 1) + if epoch <= self.warmup_steps: + return [epoch / self.warmup_steps * lr for lr in self.base_lrs] + t, T = (epoch - self.warmup_steps), (self.steps - self.warmup_steps) + return [lr * (1 - t / T) ** self.power for lr in self.base_lrs] + + +def LinearLR(optimizer: Optimizer, warmup_steps: int = 0, steps: int = 100000, last_epoch: int = -1) -> PolynomialLR: + return PolynomialLR(optimizer, warmup_steps, steps, 1, last_epoch) diff --git a/tania_scripts/supar/utils/parallel.py b/tania_scripts/supar/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f1e0b0cee6b98d07bbec7e2193fcd5bcb97343 --- /dev/null +++ b/tania_scripts/supar/utils/parallel.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import functools +import os +import re +from typing import Any, Iterable + +import torch +import torch.distributed as dist +import torch.nn as nn + + +class DistributedDataParallel(nn.parallel.DistributedDataParallel): + + def __init__(self, module, **kwargs): + super().__init__(module, **kwargs) + + def __getattr__(self, name): + wrapped = super().__getattr__('module') + if hasattr(wrapped, name): + return getattr(wrapped, name) + return super().__getattr__(name) + + +def wait(fn) -> Any: + @functools.wraps(fn) + def wrapper(*args, **kwargs): + value = None + if is_master(): + value = fn(*args, **kwargs) + if is_dist(): + dist.barrier() + value = gather(value)[0] + return value + return wrapper + + +def gather(obj: Any) -> Iterable[Any]: + objs = [None] * dist.get_world_size() + dist.all_gather_object(objs, obj) + return objs + + +def reduce(obj: Any, reduction: str = 'sum') -> Any: + objs = gather(obj) + if reduction == 'sum': + return functools.reduce(lambda x, y: x + y, objs) + elif reduction == 'mean': + return functools.reduce(lambda x, y: x + y, objs) / len(objs) + elif reduction == 'min': + return min(objs) + elif reduction == 'max': + return max(objs) + else: + raise NotImplementedError(f"Unsupported reduction {reduction}") + + +def is_dist(): + return dist.is_available() and dist.is_initialized() + + +def is_master(): + return not is_dist() or dist.get_rank() == 0 + + +def get_free_port(): + import socket + s = socket.socket() + s.bind(('', 0)) + port = str(s.getsockname()[1]) + s.close() + return port + + +def get_device_count(): + if 'CUDA_VISIBLE_DEVICES' in os.environ: + return len(re.findall(r'\d+', os.environ['CUDA_VISIBLE_DEVICES'])) + return torch.cuda.device_count() diff --git a/tania_scripts/supar/utils/tokenizer.py b/tania_scripts/supar/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..35920b63a672c8b2cb4283355c779313bb5fbebb --- /dev/null +++ b/tania_scripts/supar/utils/tokenizer.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +import re +import tempfile +from collections import Counter, defaultdict +from typing import Any, Dict, List, Optional, Union + +import torch.distributed as dist +from supar.utils.parallel import is_dist, is_master +from supar.utils.vocab import Vocab +from torch.distributions.utils import lazy_property + + +class Tokenizer: + + def __init__(self, lang: str = 'en') -> Tokenizer: + import stanza + try: + self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', verbose=False, tokenize_no_ssplit=True) + except Exception: + stanza.download(lang=lang, resources_url='stanford') + self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', verbose=False, tokenize_no_ssplit=True) + + def __call__(self, text: str) -> List[str]: + return [i.text for i in self.pipeline(text).sentences[0].tokens] + + +class TransformerTokenizer: + + def __init__(self, name) -> TransformerTokenizer: + from transformers import AutoTokenizer + self.name = name + try: + self.tokenizer = AutoTokenizer.from_pretrained(name, local_files_only=True) + except Exception: + self.tokenizer = AutoTokenizer.from_pretrained(name, local_files_only=False) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name})" + + def __len__(self) -> int: + return self.vocab_size + + def __call__(self, text: str) -> List[str]: + from tokenizers.pre_tokenizers import ByteLevel + if isinstance(self.tokenizer.backend_tokenizer.pre_tokenizer, ByteLevel): + text = ' ' + text + return self.tokenizer.tokenize(text) + + def __getattr__(self, name: str) -> Any: + return getattr(self.tokenizer, name) + + def __getstate__(self) -> Dict: + return self.__dict__ + + def __setstate__(self, state: Dict): + self.__dict__.update(state) + + @lazy_property + def vocab(self): + return defaultdict(lambda: self.tokenizer.vocab[self.unk], self.tokenizer.get_vocab()) + + @lazy_property + def tokens(self): + return sorted(self.vocab, key=lambda x: self.vocab[x]) + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def pad(self): + return self.tokenizer.pad_token + + @property + def unk(self): + return self.tokenizer.unk_token + + @property + def bos(self): + return self.tokenizer.bos_token or self.tokenizer.cls_token + + @property + def eos(self): + return self.tokenizer.eos_token or self.tokenizer.sep_token + + def decode(self, text: List) -> str: + return self.tokenizer.decode(text, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + +class BPETokenizer: + + def __init__( + self, + path: str = None, + files: Optional[List[str]] = None, + vocab_size: Optional[int] = 32000, + min_freq: Optional[int] = 2, + dropout: float = None, + backend: str = 'huggingface', + pad: Optional[str] = None, + unk: Optional[str] = None, + bos: Optional[str] = None, + eos: Optional[str] = None, + ) -> BPETokenizer: + + self.path = path + self.files = files + self.min_freq = min_freq + self.dropout = dropout or .0 + self.backend = backend + self.pad = pad + self.unk = unk + self.bos = bos + self.eos = eos + self.special_tokens = [i for i in [pad, unk, bos, eos] if i is not None] + + if backend == 'huggingface': + from tokenizers import Tokenizer + from tokenizers.decoders import BPEDecoder + from tokenizers.models import BPE + from tokenizers.pre_tokenizers import WhitespaceSplit + from tokenizers.trainers import BpeTrainer + path = os.path.join(path, 'tokenizer.json') + if is_master() and not os.path.exists(path): + # start to train a tokenizer from scratch + self.tokenizer = Tokenizer(BPE(dropout=dropout, unk_token=unk)) + self.tokenizer.pre_tokenizer = WhitespaceSplit() + self.tokenizer.decoder = BPEDecoder() + self.tokenizer.train(files=files, + trainer=BpeTrainer(vocab_size=vocab_size, + min_frequency=min_freq, + special_tokens=self.special_tokens, + end_of_word_suffix='</w>')) + self.tokenizer.save(path) + if is_dist(): + dist.barrier() + self.tokenizer = Tokenizer.from_file(path) + self.vocab = self.tokenizer.get_vocab() + + elif backend == 'subword-nmt': + import argparse + from argparse import Namespace + + from subword_nmt.apply_bpe import BPE, read_vocabulary + from subword_nmt.learn_joint_bpe_and_vocab import learn_joint_bpe_and_vocab + fmerge = os.path.join(path, 'merge.txt') + fvocab = os.path.join(path, 'vocab.txt') + separator = '@@' + if is_master() and (not os.path.exists(fmerge) or not os.path.exists(fvocab)): + with tempfile.TemporaryDirectory() as ftemp: + fall = os.path.join(ftemp, 'fall') + with open(fall, 'w') as f: + for file in files: + with open(file) as fi: + f.write(fi.read()) + learn_joint_bpe_and_vocab(Namespace(input=[argparse.FileType()(fall)], + output=argparse.FileType('w')(fmerge), + symbols=vocab_size, + separator=separator, + vocab=[argparse.FileType('w')(fvocab)], + min_frequency=min_freq, + total_symbols=False, + verbose=False, + num_workers=32)) + if is_dist(): + dist.barrier() + self.tokenizer = BPE(codes=open(fmerge), separator=separator, vocab=read_vocabulary(open(fvocab), None)) + self.vocab = Vocab(counter=Counter(self.tokenizer.vocab), + specials=self.special_tokens, + unk_index=self.special_tokens.index(unk)) + else: + raise ValueError(f'Unsupported backend: {backend} not in (huggingface, subword-nmt)') + + def __repr__(self) -> str: + s = self.__class__.__name__ + f'({self.vocab_size}, min_freq={self.min_freq}' + if self.dropout > 0: + s += f", dropout={self.dropout}" + s += f", backend={self.backend}" + if self.pad is not None: + s += f", pad={self.pad}" + if self.unk is not None: + s += f", unk={self.unk}" + if self.bos is not None: + s += f", bos={self.bos}" + if self.eos is not None: + s += f", eos={self.eos}" + s += ')' + return s + + def __len__(self) -> int: + return self.vocab_size + + def __call__(self, text: Union[str, List]) -> List[str]: + is_pretokenized = isinstance(text, list) + if self.backend == 'huggingface': + return self.tokenizer.encode(text, is_pretokenized=is_pretokenized).tokens + else: + if not is_pretokenized: + text = text.split() + return self.tokenizer.segment_tokens(text, dropout=self.dropout) + + @lazy_property + def tokens(self): + return sorted(self.vocab, key=lambda x: self.vocab[x]) + + @property + def vocab_size(self): + return len(self.vocab) + + def decode(self, text: List) -> str: + if self.backend == 'huggingface': + return self.tokenizer.decode(text) + else: + text = self.vocab[text] + text = ' '.join([i for i in text if i not in self.special_tokens]) + return re.sub(f'({self.tokenizer.separator} )|({self.tokenizer.separator} ?$)', '', text) diff --git a/tania_scripts/supar/utils/transform.py b/tania_scripts/supar/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d070c7ae27942474a8004b7688d7033582ba46 --- /dev/null +++ b/tania_scripts/supar/utils/transform.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Iterable, Optional, Tuple + +import torch +from torch.distributions.utils import lazy_property + +from supar.utils.fn import debinarize +from supar.utils.logging import get_logger, progress_bar + +logger = get_logger(__name__) + + +class Transform(object): + r""" + A :class:`Transform` object corresponds to a specific data format, which holds several instances of data fields + that provide instructions for preprocessing and numericalization, etc. + + Attributes: + training (bool): + Sets the object in training mode. + If ``False``, some data fields not required for predictions won't be returned. + Default: ``True``. + """ + + fields = [] + + def __init__(self): + self.training = True + + def __len__(self): + return len(self.fields) + + def __repr__(self): + s = '\n' + '\n'.join([f" {f}" for f in self.flattened_fields]) + '\n' + return f"{self.__class__.__name__}({s})" + + def __call__(self, sentences: Iterable[Sentence]) -> Iterable[Sentence]: + return [sentence.numericalize(self.flattened_fields) for sentence in progress_bar(sentences)] + + def __getitem__(self, index): + return getattr(self, self.fields[index]) + + @property + def flattened_fields(self): + flattened = [] + for field in self: + if field not in self.src and field not in self.tgt: + continue + if not self.training and field in self.tgt: + continue + if not isinstance(field, Iterable): + field = [field] + for f in field: + if f is not None: + flattened.append(f) + return flattened + + def train(self, training=True): + self.training = training + + def eval(self): + self.train(False) + + def append(self, field): + self.fields.append(field.name) + setattr(self, field.name, field) + + @property + def src(self): + raise AttributeError + + @property + def tgt(self): + raise AttributeError + + +class Batch(object): + + def __init__(self, sentences: Iterable[Sentence]) -> Batch: + self.sentences = sentences + + self.names, self.fields = [], {} + + def __repr__(self): + return f'{self.__class__.__name__}({", ".join([f"{name}" for name in self.names])})' + + def __len__(self): + return len(self.sentences) + + def __getitem__(self, index): + return self.fields[self.names[index]] + + def __getattr__(self, name): + return [s.fields[name] for s in self.sentences] + + def __setattr__(self, name: str, value: Iterable[Any]): + if name not in ('sentences', 'fields', 'names'): + for s, v in zip(self.sentences, value): + setattr(s, name, v) + else: + self.__dict__[name] = value + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + @property + def device(self): + return 'cuda' if torch.cuda.is_available() else 'cpu' + + @lazy_property + def lens(self): + return torch.tensor([len(i) for i in self.sentences]).to(self.device, non_blocking=True) + + @lazy_property + def mask(self): + return self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(self.lens.max()))) + + def compose(self, transform: Transform) -> Batch: + for f in transform.flattened_fields: + self.names.append(f.name) + self.fields[f.name] = f.compose([s.fields[f.name] for s in self.sentences]) + return self + + def shrink(self, batch_size: Optional[int] = None) -> Batch: + if batch_size is None: + batch_size = len(self) // 2 + if batch_size <= 0: + raise RuntimeError(f"The batch has only {len(self)} sentences and can't be shrinked!") + return Batch([self.sentences[i] for i in torch.randperm(len(self))[:batch_size].tolist()]) + + def pin_memory(self): + for s in self.sentences: + for i in s.fields.values(): + if isinstance(i, torch.Tensor): + i.pin_memory() + return self + + +class Sentence(object): + + def __init__(self, transform, index: Optional[int] = None) -> Sentence: + self.index = index + # mapping from each nested field to their proper position + self.maps = dict() + # original values and numericalized values of each position + self.values, self.fields = [], {} + for i, field in enumerate(transform): + + if not isinstance(field, Iterable): + field = [field] + for f in field: + if f is not None: + self.maps[f.name] = i + self.fields[f.name] = None + + def __contains__(self, name): + return name in self.fields + + def __getattr__(self, name): + if name in self.fields: + return self.values[self.maps[name]] + raise AttributeError(f"`{name}` not found") + + def __setattr__(self, name, value): + if 'fields' in self.__dict__ and name in self: + index = self.maps[name] + if index >= len(self.values): + self.__dict__[name] = value + else: + self.values[index] = value + else: + self.__dict__[name] = value + + def __getstate__(self): + state = vars(self) + if 'fields' in state: + state['fields'] = { + name: ((value.dtype, value.tolist()) + if isinstance(value, torch.Tensor) + else value) + for name, value in state['fields'].items() + } + return state + + def __setstate__(self, state): + if 'fields' in state: + state['fields'] = { + name: (torch.tensor(value[1], dtype=value[0]) + if isinstance(value, tuple) and isinstance(value[0], torch.dtype) + else value) + for name, value in state['fields'].items() + } + self.__dict__.update(state) + + def __len__(self): + try: + return len(next(iter(self.fields.values()))) + except Exception: + raise AttributeError("Cannot get size of a sentence with no fields") + + @lazy_property + def size(self): + return len(self) + + def numericalize(self, fields): + for f in fields: + self.fields[f.name] = next(f.transform([getattr(self, f.name)])) + self.pad_index = fields[0].pad_index + return self + + @classmethod + def from_cache(cls, fbin: str, pos: Tuple[int, int]) -> Sentence: + return debinarize(fbin, pos) diff --git a/tania_scripts/supar/utils/vocab.py b/tania_scripts/supar/utils/vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9cae0a04a30bf92d1799525f6bfd73b36a4046 --- /dev/null +++ b/tania_scripts/supar/utils/vocab.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import Counter, defaultdict +from typing import Iterable, Tuple, Union + + +class Vocab(object): + r""" + Defines a vocabulary object that will be used to numericalize a field. + + Args: + counter (~collections.Counter): + :class:`~collections.Counter` object holding the frequencies of each value found in the data. + min_freq (int): + The minimum frequency needed to include a token in the vocabulary. Default: 1. + specials (Tuple[str]): + The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: ``[]``. + unk_index (int): + The index of unk token. Default: 0. + + Attributes: + itos: + A list of token strings indexed by their numerical identifiers. + stoi: + A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers. + """ + + def __init__(self, counter: Counter, min_freq: int = 1, specials: Tuple = tuple(), unk_index: int = 0) -> Vocab: + self.itos = list(specials) + self.stoi = defaultdict(lambda: unk_index) + self.stoi.update({token: i for i, token in enumerate(self.itos)}) + self.update([token for token, freq in counter.items() if freq >= min_freq]) + self.unk_index = unk_index + self.n_init = len(self) + + def __len__(self): + return len(self.itos) + + def __getitem__(self, key: Union[int, str, Iterable]) -> Union[str, int, Iterable]: + if isinstance(key, str): + return self.stoi[key] + elif not isinstance(key, Iterable): + return self.itos[key] + elif isinstance(key[0], str): + return [self.stoi[i] for i in key] + else: + return [self.itos[i] for i in key] + + def __contains__(self, token): + return token in self.stoi + + def __getstate__(self): + # avoid picking defaultdict + attrs = dict(self.__dict__) + # cast to regular dict + attrs['stoi'] = dict(self.stoi) + return attrs + + def __setstate__(self, state): + stoi = defaultdict(lambda: self.unk_index) + stoi.update(state['stoi']) + state['stoi'] = stoi + self.__dict__.update(state) + + def items(self): + return self.stoi.items() + + def update(self, vocab: Union[Iterable[str], Vocab, Counter]) -> Vocab: + if isinstance(vocab, Vocab): + vocab = vocab.itos + # NOTE: PAY CAREFUL ATTENTION TO DICT ORDER UNDER DISTRIBUTED TRAINING! + vocab = sorted(set(vocab).difference(self.stoi)) + self.itos.extend(vocab) + self.stoi.update({token: i for i, token in enumerate(vocab, len(self.stoi))}) + return self diff --git a/tania_scripts/supar/vector_quantize.py b/tania_scripts/supar/vector_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9a6ac374ae3c04ed5ad738aca34c909f31eaa6 --- /dev/null +++ b/tania_scripts/supar/vector_quantize.py @@ -0,0 +1,130 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from clusopt_core.cluster import Streamkm + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories, eps=1e-5): + return (x + eps) / (x.sum() + n_categories * eps) + + +class VectorQuantize(nn.Module): + # Based on: https://github.com/lucidrains/vector-quantize-pytorch + def __init__( + self, + dim, + n_embed, + decay=0.8, + commitment=1.0, + eps=1e-5, + wait_steps=0, + observe_steps=1245, + coreset_size_multiplier=10, + ): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + self.commitment = commitment + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", nn.Parameter(embed)) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", nn.Parameter(embed.clone())) + + self.wait_steps_remaining = wait_steps + self.observe_steps_remaining = observe_steps + self.clustering_model = Streamkm( + coresetsize=n_embed * coreset_size_multiplier, + length=1500000, + seed=42, + ) + self.data_chunks = [] + + def stream_cluster(self, input, expected_num_tokens=None): + input = input.reshape(-1, self.dim) + input_np = input.detach().cpu().numpy() + assert len(input.shape) == 2 + self.data_chunks.append(input_np) + if ( + expected_num_tokens is not None + and sum([chunk.shape[0] for chunk in self.data_chunks]) + < expected_num_tokens + ): + return # This is not the last sub-batch. + if self.wait_steps_remaining > 0: + self.wait_steps_remaining -= 1 + self.data_chunks.clear() + return + + self.observe_steps_remaining -= 1 + input_np = np.concatenate(self.data_chunks, axis=0) + self.data_chunks.clear() + self.clustering_model.partial_fit(input_np) + if self.observe_steps_remaining == 0: + print("\nInitializing vq clusters (this may take a while)...") + clusters, _ = self.clustering_model.get_final_clusters( + self.n_embed, seed=42 + ) + new_embed = torch.tensor( + clusters.T, dtype=self.embed.dtype, device=self.embed.device + ) + self.embed.copy_(new_embed) + # Don't set initial cluster sizes to zero! If a cluster is rare, + # embed_avg will be undergoing exponential decay until it's seen for + # the first time. If cluster_size is zero, this will lead to *embed* + # also undergoing exponential decay towards the origin before the + # cluster is ever encountered. Initializing to 1.0 will instead will + # instead leave embed in place for many iterations, up until + # cluster_size finally decays to near-zero. + self.cluster_size.fill_(1.0) + self.embed_avg.copy_(new_embed) + + def forward(self, input, expected_num_tokens=None): + if self.observe_steps_remaining > 0: + if self.training: + self.stream_cluster(input, expected_num_tokens) + return ( + input, + torch.zeros(input.shape[0], + dtype=torch.long, device=input.device), + torch.tensor(0.0, dtype=input.dtype, device=input.device), + None + ) + + dtype = input.dtype + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = F.embedding(embed_ind, self.embed.transpose(0, 1)) + + if self.training: + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + ema_inplace(self.embed_avg, embed_sum, self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.n_embed, self.eps) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + loss = F.mse_loss(quantize.detach(), input) * self.commitment + quantize = input + (quantize - input).detach() + quantize = torch.reshape(quantize, input.size()) + return quantize, embed_ind, loss, dist diff --git a/tania_scripts/tania-some-other-metrics.ipynb b/tania_scripts/tania-some-other-metrics.ipynb index e23180be4b2b3ebb0fa7b86798291f716322a6a5..ceb6052a49e387b6e1af94e850c3bffb5c13fcd0 100644 --- a/tania_scripts/tania-some-other-metrics.ipynb +++ b/tania_scripts/tania-some-other-metrics.ipynb @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "2a882cc9-8f9d-4457-becb-d2e26ab3f14f", "metadata": {}, "outputs": [ @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "8897dcc3-4218-4ee5-9984-17b9a6d8dce2", "metadata": {}, "outputs": [], @@ -163,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "1363f307-fa4b-43ba-93d5-2d1c11ceb9e4", "metadata": {}, "outputs": [ @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "1362e192-514a-4a77-a8cb-5c012026e2bb", "metadata": {}, "outputs": [], @@ -238,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "544ff6aa-4104-4580-a01f-97429ffcc228", "metadata": {}, "outputs": [ @@ -328,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "b9052dc2-ce45-4af4-a0a0-46c60a13da12", "metadata": {}, "outputs": [], @@ -388,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "1e9dd0fb-db6a-47d1-8bfb-1015845f6d3e", "metadata": {}, "outputs": [ @@ -398,7 +398,7 @@ "{'Flesch-Douma': 88.68, 'LIX': 11.55, 'Kandel-Moles': 5.86}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -421,7 +421,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "24bc84a5-b2df-4194-838a-8f24302599bd", "metadata": {}, "outputs": [], @@ -467,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "0cdb972f-31b6-4e7e-82a8-371eda344f2c", "metadata": {}, "outputs": [ @@ -477,7 +477,7 @@ "{'Average Word Length': 3.79, 'Average Sentence Length': 7.0}" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -521,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "56af520c-d56b-404a-aebf-ad7c2a9ca503", "metadata": {}, "outputs": [], @@ -567,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "f7c8b125-4651-4b21-bcc4-93ef78a4239b", "metadata": {}, "outputs": [ @@ -603,7 +603,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 15, "id": "daa17c33-adca-4695-90eb-741579382939", "metadata": {}, "outputs": [], @@ -622,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "id": "80d8fa08-6b7d-4ab7-85cd-987823639277", "metadata": {}, "outputs": [ @@ -665,7 +665,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "id": "3f9c7dc7-6820-4013-a85c-2af4f846d4f5", "metadata": {}, "outputs": [ @@ -693,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "id": "65e1a630-c46e-4b18-9831-b97864de53ee", "metadata": {}, "outputs": [], @@ -713,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "id": "1612e911-12a8-47c9-b811-b2d6885c3647", "metadata": {}, "outputs": [ @@ -742,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "925a3a75-aaaa-4851-b77b-b42cb1e21e11", "metadata": {}, "outputs": [], @@ -757,7 +757,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "id": "6fa60897-ad26-43b4-b8de-861290ca6bd3", "metadata": {}, "outputs": [ @@ -783,25 +783,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 23, "id": "f3678462-e572-4ce5-8d3d-a5389b2356c8", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Defaulting to user installation because normal site-packages is not writeable\n", - "Collecting scipy\n", - " Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", - "Requirement already satisfied: numpy<2.5,>=1.23.5 in /public/conda/Miniconda/envs/pytorch-2.6/lib/python3.11/site-packages (from scipy) (2.2.4)\n", - "Downloading scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", - "\u001b[?25hInstalling collected packages: scipy\n", - "Successfully installed scipy-1.15.3\n" - ] - } - ], + "outputs": [], "source": [ "#!pip3 install seaborn\n", "#!pip3 install scipy" @@ -809,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 24, "id": "b621b2a8-488f-44db-b085-fe156f453943", "metadata": {}, "outputs": [ @@ -830,7 +815,7 @@ "import matplotlib.pyplot as plt\n", "from scipy.stats import spearmanr\n", "\n", - "# Sample data (replace with your real values)\n", + "# Sample data (to be replaces with real values)\n", "data = {\n", " \"perplexity\": [32.5, 45.2, 28.1, 39.0, 50.3],\n", " \"avg_word_length\": [4.1, 4.3, 4.0, 4.2, 4.5],\n", @@ -853,10 +838,123 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "45ee04fc-acab-4bba-ba06-e4cf4bca9fe5", + "metadata": {}, + "source": [ + "## Tree depth" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "79f99787-c220-4f1d-93a9-59230363ec3f", + "metadata": {}, + "outputs": [], + "source": [ + "def parse_sentence_block(text):\n", + " lines = text.strip().split('\\n')\n", + " result = []\n", + " tokenlist = []\n", + " for line in lines:\n", + " # Split the line by tab and strip whitespace\n", + " parts = tuple(line.strip().split('\\t'))\n", + " # Only include lines that have exactly 4 parts\n", + " if len(parts) == 4:\n", + " parentidx = int(parts[3])\n", + " if '@@' in parts[2]:\n", + " nonterm1 = parts[2].split('@@')[0]\n", + " nonterm2 = parts[2].split('@@')[1]\n", + " else:\n", + " nonterm1 = parts[2]\n", + " nonterm2 = '<nul>'\n", + " postag = parts[1]\n", + " token = parts[0]\n", + " result.append((parentidx, nonterm1, nonterm2, postag))\n", + " tokenlist.append(token)\n", + " return result, tokenlist\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "f567efb0-8b0b-4782-9345-052cf1785776", + "metadata": {}, + "outputs": [], + "source": [ + "example_sentence = \"\"\"\n", + "<s>\t<s>\t<s>\t1\n", + "--\tponct\t<nul>@@<nul>\t1\n", + "Eh\tnpp\t<nul>@@<nul>\t1\n", + "bien?\tadv\tAP@@<nul>\t1\n", + "fit\tv\tVN@@<nul>\t2\n", + "-il\tcls-suj\tVN@@VPinf-OBJ\t3\n", + ".\tponct\t<nul>@@<nul>\t4\n", + "</s>\t</s>\t</s>\t4\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8d4ecba9-89b8-4000-a061-aa16aa68a404", + "metadata": {}, + "outputs": [], + "source": [ + "from transform import *\n", + "\n", + "def visualize_const_prediction(example_sent):\n", + " parsed, tokenlist = parse_sentence_block(example_sent)\n", + " tree = AttachJuxtaposeTree.totree(tokenlist, 'SENT')\n", + " AttachJuxtaposeTree.action2tree(tree, parsed).pretty_print()\n", + " nltk_tree = AttachJuxtaposeTree.action2tree(tree, parsed)\n", + " #print(\"NLTK TREE\", nltk_tree)\n", + " depth = nltk_tree.height() - 1 # NLTK includes the leaf level as height 1, so subtract 1 for tree depth \n", + " print(\"Tree depth:\", depth)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "bfd3abf3-b83a-4817-85ad-654daf72be88", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " SENT \n", + " | \n", + " <s> \n", + " | \n", + " <s> \n", + " ______________|__________ \n", + " | | | AP \n", + " | | | __________|________ \n", + " | | | | VPinf-OBJ \n", + " | | | | ______________|_______ \n", + " | | | | | VN \n", + " | | | | | ________________|____ \n", + " | | | | VN | | </s>\n", + " | | | | | | | | \n", + " | ponct npp adv v cls-suj ponct </s>\n", + " | | | | | | | | \n", + "<s> -- Eh bien? fit -il . </s>\n", + "\n", + "Tree depth: 8\n" + ] + } + ], + "source": [ + "visualize_const_prediction(example_sentence)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "3a6e3b53-7104-45ef-a4b5-e831bdd6ca6f", + "id": "bc51ab44-6885-45cc-bad2-6a43a7791fdb", "metadata": {}, "outputs": [], "source": [] diff --git a/tania_scripts/transform.py b/tania_scripts/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..56e0f5158b6b3b214b7e99e8a479fbabc56bcbf6 --- /dev/null +++ b/tania_scripts/transform.py @@ -0,0 +1,459 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union + +import nltk +import torch + +from supar.models.const.crf.transform import Tree +from supar.utils.common import NUL +from supar.utils.logging import get_logger +from supar.utils.tokenizer import Tokenizer +from supar.utils.transform import Sentence + +if TYPE_CHECKING: + from supar.utils import Field + +logger = get_logger(__name__) + + +class AttachJuxtaposeTree(Tree): + r""" + :class:`AttachJuxtaposeTree` is derived from the :class:`Tree` class, + supporting back-and-forth transformations between trees and AttachJuxtapose actions :cite:`yang-deng-2020-aj`. + + Attributes: + WORD: + Words in the sentence. + POS: + Part-of-speech tags, or underscores if not available. + TREE: + The raw constituency tree in :class:`nltk.tree.Tree` format. + NODE: + The target node on each rightmost chain. + PARENT: + The label of the parent node of each terminal. + NEW: + The label of each newly inserted non-terminal with a target node and a terminal as juxtaposed children. + ``NUL`` represents the `Attach` action. + """ + + fields = ['WORD', 'POS', 'TREE', 'NODE', 'PARENT', 'NEW'] + + def __init__( + self, + WORD: Optional[Union[Field, Iterable[Field]]] = None, + POS: Optional[Union[Field, Iterable[Field]]] = None, + TREE: Optional[Union[Field, Iterable[Field]]] = None, + NODE: Optional[Union[Field, Iterable[Field]]] = None, + PARENT: Optional[Union[Field, Iterable[Field]]] = None, + NEW: Optional[Union[Field, Iterable[Field]]] = None + ) -> Tree: + super().__init__() + + self.WORD = WORD + self.POS = POS + self.TREE = TREE + self.NODE = NODE + self.PARENT = PARENT + self.NEW = NEW + + @property + def src(self): + return self.WORD, self.POS, self.TREE + + @property + def tgt(self): + return self.NODE, self.PARENT, self.NEW + + @classmethod + def tree2action(cls, tree: nltk.Tree): + r""" + Converts a constituency tree into AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + A constituency tree in :class:`nltk.tree.Tree` format. + + Returns: + A sequence of AttachJuxtapose actions. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = nltk.Tree.fromstring(''' + (TOP + (S + (NP (_ Arthur)) + (VP + (_ is) + (NP (NP (_ King)) (PP (_ of) (NP (_ the) (_ Britons))))) + (_ .))) + ''') + >>> tree.pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + >>> AttachJuxtaposeTree.tree2action(tree) + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')] + """ + + def isroot(node): + return node == tree[0] + + def isterminal(node): + return len(node) == 1 and not isinstance(node[0], nltk.Tree) + + def last_leaf(node): + pos = () + while True: + pos += (len(node) - 1,) + node = node[-1] + if isterminal(node): + return node, pos + + def parent(position): + return tree[position[:-1]] + + def grand(position): + return tree[position[:-2]] + + def detach(tree): + last, last_pos = last_leaf(tree) + siblings = parent(last_pos)[:-1] + + if len(siblings) > 0: + last_subtree = last + last_subtree_siblings = siblings + parent_label = NUL + else: + last_subtree, last_pos = parent(last_pos), last_pos[:-1] + last_subtree_siblings = [] if isroot(last_subtree) else parent(last_pos)[:-1] + parent_label = last_subtree.label() + + target_pos, new_label, last_tree = 0, NUL, tree + if isroot(last_subtree): + last_tree = None + + elif len(last_subtree_siblings) == 1 and not isterminal(last_subtree_siblings[0]): + new_label = parent(last_pos).label() + new_label = new_label + target = last_subtree_siblings[0] + last_grand = grand(last_pos) + if last_grand is None: + last_tree = targetistermina + else: + last_grand[-1] = target + target_pos = len(last_pos) - 2 + else: + target = parent(last_pos) + target.pop() + target_pos = len(last_pos) - 2 + action = target_pos, parent_label, new_label + return action, last_tree + if tree is None: + return [] + action, last_tree = detach(tree) + return cls.tree2action(last_tree) + [action] + + @classmethod + def action2tree( + cls, + tree: nltk.Tree, + actions: List[Tuple[int, str, str]], + join: str = '::', + ) -> nltk.Tree: + r""" + Recovers a constituency tree from a sequence of AttachJuxtapose actions. + + Args: + tree (nltk.tree.Tree): + An empty tree that provides a base for building a result tree. + actions (List[Tuple[int, str, str]]): + A sequence of AttachJuxtapose actions. + join (str): + A string used to connect collapsed node labels. Non-terminals containing this will be expanded to unary chains. + Default: ``'::'``. + + Returns: + A result constituency tree. + + Examples: + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.action2tree(tree, + [(0, 'NP', '<nul>'), (0, 'VP', 'S'), (1, 'NP', '<nul>'), + (2, 'PP', 'NP'), (3, 'NP', '<nul>'), (4, '<nul>', '<nul>'), + (0, '<nul>', '<nul>')]).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + """ + + def target(node, depth): + node_pos = () + for _ in range(depth): + node_pos += (len(node) - 1,) + node = node[-1] + return node, node_pos + + def parent(tree, position): + return tree[position[:-1]] + + def execute(tree: nltk.Tree, terminal: Tuple(str, str), action: Tuple[int, str, str]) -> nltk.Tree: + target_pos, parent_label, new_label, post = action + #print(target_pos, parent_label, new_label) + new_leaf = nltk.Tree(post, [terminal[0]]) + + # create the subtree to be inserted + new_subtree = new_leaf if parent_label == NUL else nltk.Tree(parent_label, [new_leaf]) + # find the target position at which to insert the new subtree + target_node = tree + if target_node is not None: + target_node, target_pos = target(target_node, target_pos) + + # Attach + if new_label == NUL: + # attach the first token + if target_node is None: + return new_subtree + target_node.append(new_subtree) + # Juxtapose + else: + new_subtree = nltk.Tree(new_label, [target_node, new_subtree]) + if len(target_pos) > 0: + parent_node = parent(tree, target_pos) + parent_node[-1] = new_subtree + else: + tree = new_subtree + return tree + + tree, root, terminals = None, tree.label(), tree.pos() + for terminal, action in zip(terminals, actions): + tree = execute(tree, terminal, action) + # recover unary chains + nodes = [tree] + while nodes: + node = nodes.pop() + if isinstance(node, nltk.Tree): + nodes.extend(node) + if join in node.label(): + labels = node.label().split(join) + node.set_label(labels[0]) + subtree = nltk.Tree(labels[-1], node) + for label in reversed(labels[1:-1]): + subtree = nltk.Tree(label, [subtree]) + node[:] = [subtree] + return nltk.Tree(root, [tree]) + + @classmethod + def action2span( + cls, + action: torch.Tensor, + spans: torch.Tensor = None, + nul_index: int = -1, + mask: torch.BoolTensor = None + ) -> torch.Tensor: + r""" + Converts a batch of the tensorized action at a given step into spans. + + Args: + action (~torch.Tensor): ``[3, batch_size]``. + A batch of the tensorized action at a given step, containing indices of target nodes, parent and new labels. + spans (~torch.Tensor): + Spans generated at previous steps, ``None`` at the first step. Default: ``None``. + nul_index (int): + The index for the obj:`NUL` token, representing the Attach action. Default: -1. + mask (~torch.BoolTensor): ``[batch_size]``. + The mask for covering the unpadded tokens. + + Returns: + A tensor representing a batch of spans for the given step. + + Examples: + >>> from collections import Counter + >>> from supar.models.const.aj.transform import AttachJuxtaposeTree, Vocab + >>> from supar.utils.common import NUL + >>> nodes, parents, news = zip(*[(0, 'NP', NUL), (0, 'VP', 'S'), (1, 'NP', NUL), + (2, 'PP', 'NP'), (3, 'NP', NUL), (4, NUL, NUL), + (0, NUL, NUL)]) + >>> vocab = Vocab(Counter(sorted(set([*parents, *news])))) + >>> actions = torch.tensor([nodes, vocab[parents], vocab[news]]).unsqueeze(1) + >>> spans = None + >>> for action in actions.unbind(-1): + ... spans = AttachJuxtaposeTree.action2span(action, spans, vocab[NUL]) + ... + >>> spans + tensor([[[-1, 1, -1, -1, -1, -1, -1, 3], + [-1, -1, -1, -1, -1, -1, 4, -1], + [-1, -1, -1, 1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, 2, -1], + [-1, -1, -1, -1, -1, -1, 1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1]]]) + >>> sequence = torch.where(spans.ge(0)) + >>> sequence = list(zip(sequence[1].tolist(), sequence[2].tolist(), vocab[spans[sequence]])) + >>> sequence + [(0, 1, 'NP'), (0, 7, 'S'), (1, 6, 'VP'), (2, 3, 'NP'), (2, 6, 'NP'), (3, 6, 'PP'), (4, 6, 'NP')] + >>> tree = AttachJuxtaposeTree.totree(['Arthur', 'is', 'King', 'of', 'the', 'Britons', '.'], 'TOP') + >>> AttachJuxtaposeTree.build(tree, sequence).pretty_print() + TOP + | + S + ______________|_______________________ + | VP | + | ________|___ | + | | NP | + | | ________|___ | + | | | PP | + | | | _______|___ | + NP | NP | NP | + | | | | ___|_____ | + _ _ _ _ _ _ _ + | | | | | | | + Arthur is King of the Britons . + + """ + + # [batch_size] + target, parent, new = action + if spans is None: + spans = action.new_full((action.shape[1], 2, 2), -1) + spans[:, 0, 1] = parent + return spans + if mask is None: + mask = torch.ones_like(target, dtype=bool) + juxtapose_mask = new.ne(nul_index) & mask + # ancestor nodes are those on the rightmost chain and higher than the target node + # [batch_size, seq_len] + rightmost_mask = spans[..., -1].ge(0) + ancestors = rightmost_mask.cumsum(-1).masked_fill_(~rightmost_mask, -1) - 1 + # should not include the target node for the Juxtapose action + ancestor_mask = mask.unsqueeze(-1) & ancestors.ge(0) & ancestors.le((target - juxtapose_mask.long()).unsqueeze(-1)) + target_pos = torch.where(ancestors.eq(target.unsqueeze(-1))[juxtapose_mask])[-1] + # the right boundaries of ancestor nodes should be aligned with the new generated terminals + spans = torch.cat((spans, torch.where(ancestor_mask, spans[..., -1], -1).unsqueeze(-1)), -1) + spans[..., -2].masked_fill_(ancestor_mask, -1) + spans[juxtapose_mask, target_pos, -1] = new.masked_fill(new.eq(nul_index), -1)[juxtapose_mask] + spans[mask, -1, -1] = parent.masked_fill(parent.eq(nul_index), -1)[mask] + # [batch_size, seq_len+1, seq_len+1] + spans = torch.cat((spans, torch.full_like(spans[:, :1], -1)), 1) + return spans + + def load( + self, + data: Union[str, Iterable], + lang: Optional[str] = None, + **kwargs + ) -> List[AttachJuxtaposeTreeSentence]: + r""" + Args: + data (Union[str, Iterable]): + A filename or a list of instances. + lang (str): + Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. + ``None`` if tokenization is not required. + Default: ``None``. + + Returns: + A list of :class:`AttachJuxtaposeTreeSentence` instances. + """ + + if lang is not None: + tokenizer = Tokenizer(lang) + if isinstance(data, str) and os.path.exists(data): + if data.endswith('.txt'): + data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1) + else: + data = open(data) + else: + if lang is not None: + data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)] + else: + data = [data] if isinstance(data[0], str) else data + + index = 0 + for s in data: + + try: + tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root) + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + except ValueError: + logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!") + continue + except IndexError: + tree = nltk.Tree.fromstring('(S ' + s + ')') + sentence = AttachJuxtaposeTreeSentence(self, tree, index) + else: + yield sentence + index += 1 + self.root = tree.label() + + +class AttachJuxtaposeTreeSentence(Sentence): + r""" + Args: + transform (AttachJuxtaposeTree): + A :class:`AttachJuxtaposeTree` object. + tree (nltk.tree.Tree): + A :class:`nltk.tree.Tree` object. + index (Optional[int]): + Index of the sentence in the corpus. Default: ``None``. + """ + + def __init__( + self, + transform: AttachJuxtaposeTree, + tree: nltk.Tree, + index: Optional[int] = None + ) -> AttachJuxtaposeTreeSentence: + super().__init__(transform, index) + + words, tags = zip(*tree.pos()) + nodes, parents, news = None, None, None + if transform.training: + oracle_tree = tree.copy(True) + # the root node must have a unary chain + if len(oracle_tree) > 1: + oracle_tree[:] = [nltk.Tree('*', oracle_tree)] + oracle_tree.collapse_unary(joinChar='::') + if len(oracle_tree) == 1 and not isinstance(oracle_tree[0][0], nltk.Tree): + oracle_tree[0] = nltk.Tree('*', [oracle_tree[0]]) + nodes, parents, news = zip(*transform.tree2action(oracle_tree)) + tags = [x.split("##")[0] for x in tags] + self.values = [words, tags, tree, nodes, parents, news] + + def __repr__(self): + return self.values[-4].pformat(1000000) + + def pretty_print(self): + self.values[-4].pretty_print()