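# NOTE: page-navigation residue from the snippet site this file was copied from
# ("Skip to content" / "Snippets Groups Projects"); not part of the program.
# File: alternate_slu.py (3.06 KiB)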
from xml.etree import ElementTree as ET
from collections import defaultdict

# Size of the context window: SLU pads/truncates query histories to this many
# tokens, and each Keyword stores the history_length - 1 tokens preceding it.
history_length = 5
# Inverted index shared by the whole module: history token -> list of
# Keyword.num values. Keyword.__init__ appends to it (once per occurrence of
# the token in the keyword's history) and SLU.process reads it.
_index = defaultdict(list)

class Keyword:
    """One keyword occurrence plus the context tokens that preceded it.

    Constructing an instance has a module-level side effect: ``num`` is
    appended to ``_index[token]`` for every token in ``history`` (so a token
    that occurs twice in ``history`` registers ``num`` twice).

    Attributes set on the instance:
        num      -- identifier passed in by the caller.
        history  -- the context tokens, stored as given (not copied).
        word     -- the keyword itself.
        action   -- the action associated with the keyword.
        section  -- ``int(section) - 1`` (presumably turning a 1-based
                    section id into a 0-based one -- confirm with the data).
    """

    def __init__(self, num, history, word, action, section):
        self.section = int(section) - 1
        self.action = action
        self.num = num
        self.word = word
        self.history = history
        # Register this keyword under each of its context tokens.
        for token in history:
            _index[token].append(num)

class SLU:
    """Keyword matcher loaded from an XML file of sections, sequences and keywords.

    ``self.words`` holds one ``Keyword`` per ``<keyword>`` element that carries
    a non-empty ``action`` attribute, in document order; a keyword's ``num`` is
    its position in that list.

    NOTE(review): candidate lookup goes through the module-level ``_index``
    that ``Keyword`` fills in, so the index is shared by every ``SLU`` created
    in the same process -- confirm that only one instance is ever built.
    """

    def __init__(self, xml_filename):
        """Load the keywords found in ``xml_filename`` (see ``parse_xml``)."""
        self.words = []
        self.parse_xml(xml_filename)

    @staticmethod
    def _tokens(text):
        """Split ``text`` on whitespace; ``None`` or blank text gives ``[]``."""
        return text.split() if text else []

    def parse_xml(self, filename):
        """Append a ``Keyword`` to ``self.words`` for every actionable keyword.

        For each ``<section>``, the running token history starts with the
        integer padding ``history_length - 1 .. 0`` and is extended with the
        text of each ``<sequence>``, of each child element that is a
        ``<keyword>``, and with the tail text of every child element. A
        keyword is recorded when its ``action`` attribute is present and not
        blank; its word is the last token seen so far and its history is the
        ``history_length - 1`` tokens before that word.

        ``filename`` is handed to ``ElementTree.parse`` unchanged.
        """
        root = ET.parse(filename)
        for section in root.findall(".//section"):
            history = list(reversed(range(history_length)))
            for sequence in section.findall('./sequence'):
                # ElementTree reports missing text/tail as None (e.g. when a
                # tag is directly followed by another tag), hence _tokens().
                history.extend(self._tokens(sequence.text))
                for node in sequence:
                    if node.tag == 'keyword':
                        history.extend(self._tokens(node.text))
                        action = node.get('action')
                        # A missing attribute is treated like an empty one.
                        if action is not None and action.strip() != '':
                            self.words.append(Keyword(
                                len(self.words),
                                history[-history_length: -1],
                                history[-1],
                                action,
                                section.get("id")))
                    history.extend(self._tokens(node.tail))

    def process(self, history, target, expected=-1, section=-1):
        """Return the best-matching keyword for ``target``, or ``None``.

        Parameters:
            history  -- tokens that preceded ``target``; padded on the left
                        with the same integer padding used by ``parse_xml``
                        and cut to the last ``history_length`` entries.
            target   -- the word to match. An empty target matches nothing.
            expected -- ``num`` of the keyword expected next, or -1 for no
                        expectation. When given, keywords more than 10
                        positions away are ignored and ties on score go to
                        the keyword closest to ``expected``.
            section  -- restrict matches to this section, or -1 for any.

        A candidate's score is the fraction of ``history_length`` covered by
        tokens shared with the query history plus the size of
        ``set(word) & set(target)`` divided by ``len(target)`` (for strings
        this is a character overlap). Candidates with either part equal to
        zero are discarded. Candidates are visited in ascending ``num`` so
        that remaining ties always resolve to the lowest ``num``.
        """
        if not target:
            # Nothing to compare against (and len(target) would be a zero
            # divisor below).
            return None

        padding = list(reversed(range(history_length)))
        history = (padding + list(history))[-history_length:]

        found = set()
        for token in history:
            # .get() so that unseen query tokens do not grow the shared index.
            found.update(_index.get(token, ()))

        history = set(history)
        target_chars = set(target)
        target_size = float(len(target))

        best = None
        best_value = 0
        best_distance = None

        for num in sorted(found):
            word = self.words[num]
            if section != -1 and word.section != section:
                continue
            history_score = len(history & set(word.history)) / float(history_length)
            target_score = len(set(word.word) & target_chars) / target_size
            if history_score == 0 or target_score == 0:
                continue
            distance = abs(word.num - expected)
            if expected != -1 and distance > 10:
                continue
            value = history_score + target_score
            if value > best_value:
                # A strictly better score replaces the leader together with
                # its distance, so later ties are compared against the
                # current leader rather than a weaker earlier candidate.
                best, best_value, best_distance = word, value, distance
            elif expected != -1 and value == best_value and distance < best_distance:
                best, best_distance = word, distance

        return best

    def last_in_section(self, word):
        """Return True if ``word`` is the last keyword of its section.

        That is the case when it is the last entry of ``self.words`` or when
        the entry following it belongs to a different section.
        """
        if word.num == len(self.words) - 1:
            return True
        return self.words[word.num + 1].section != word.section

    def expected_at_section_start(self, section):
        """Return the ``num`` of the first keyword in ``section``, or -1."""
        return next((word.num for word in self.words if word.section == section), -1)

if __name__ == '__main__':
    # Manual smoke test: load the bundled data file and look up one word
    # given the two words that preceded it.
    slu = SLU('data/homeostasis_25nov.xml')
    history = 'open technical'.split()
    word = 'characteristics'
    found = slu.process(history, word)
    if found is None:
        # process() returns None when no keyword matches.
        print('no match for %r after %r' % (word, history))
    else:
        # One pre-formatted argument so the output is the same whether print
        # is the Python 2 statement or the Python 3 function.
        print('%s %s %s %s %s' % (found.word, found.section, found.action, found.history, word))