diff --git a/.options.txt b/.options.txt index 519358cb5e3e87a834a30531b70c8790c8618548..500e4c227f5eeac0af36752c825618101c7e79e1 100644 --- a/.options.txt +++ b/.options.txt @@ -1,4 +1,5 @@ +slu: 1 xml_filename: data/homeostasis_25nov.xml osc_host: 127.0.0.1 osc_port: 1234 -asr_model: asr/fisher-benoit-1.cfg +asr_model: asr/mika-fred-2.cfg diff --git a/alternate_slu.py b/alternate_slu.py new file mode 100644 index 0000000000000000000000000000000000000000..22b9677a8f6fc4565b784760b6621dc5022c0361 --- /dev/null +++ b/alternate_slu.py @@ -0,0 +1,76 @@ +from xml.etree import ElementTree as ET +from collections import defaultdict + +history_length = 5 +_index = defaultdict(list) + +class Keyword: + def __init__(self, num, history, word, action, section): + self.history = history + self.word = word + self.num = num + self.action = action + self.section = int(section) - 1 + for word in history: + _index[word].append(num) + +class SLU: + def __init__(self, xml_filename): + self.words = [] + self.parse_xml(xml_filename) + + def parse_xml(self, filename): + root = ET.parse(filename) + for section in root.findall(".//section"): + history = [x for x in reversed(range(history_length))] + for sequence in section.findall('./sequence'): + history.extend(sequence.text.strip().split()) + for node in sequence: + if node.tag == 'keyword': + history.extend(node.text.strip().split()) + if node.get('action').strip() != '': + self.words.append(Keyword(len(self.words), history[-history_length: -1], history[-1], node.get('action'), section.get("id"))) + history.extend(node.tail.strip().split()) + + def process(self, history, target, expected=-1, section=-1): + history = ([x for x in reversed(range(history_length))] + history)[-history_length:] + found = set() + for word in history: + found.update(_index[word]) + #print len(found) + + history = set(history) + min_distance = None + max_value = 0 + argmax = None + + for word in [self.words[num] for num in found]: + if section != -1 and word.section != section: + continue + history_score = len(history & set(word.history)) / float(history_length) + target_score = len(set(word.word) & set(target)) / float(len(target)) + if history_score == 0 or target_score == 0: + continue + value = history_score + target_score + if value > max_value: + max_value = value + argmax = word + if expected != -1 and value == max_value and (min_distance == None or abs(word.num - expected) < min_distance): + min_distance = abs(word.num - expected) + max_value = value + argmax = word + + return argmax + + def last_in_section(self, word): + if word.num == len(self.words) - 1 or self.words[word.num + 1].section != word.section: + return True + return False + +if __name__ == '__main__': + slu = SLU('data/homeostasis_25nov.xml') + history = 'open technical'.split() + word = 'characteristics' + found = slu.process(history, word) + print found.word, found.section, found.action, found.history, word + diff --git a/animate.py b/animate.py index 4d4eea10666efb06e51d0c44b422272eb2a626cd..96985621a76a05333c4c5011b2c9bae381106ae6 100644 --- a/animate.py +++ b/animate.py @@ -31,4 +31,4 @@ def scroll_to(scrollview, widget): result = widget.translate_coordinates(scrollview.get_child().get_child(), 0, 0) if result: adj = scrollview.get_vadjustment() - animate_value(adj.set_value, adj.get_value(), result[1]) + animate_value(adj.set_value, adj.get_value(), result[1] - 150) diff --git a/asr.py b/asr.py index 54166c0d1079a7321bf12a2eedc96996aa0976c3..9558fb44b283f7671c0a3ef829c351d1db6b0965 100644 --- a/asr.py +++ b/asr.py @@ -11,7 +11,7 @@ Gst.init(None) import osc class ASR(Gtk.HBox): - def __init__(self, asr_config, hyp_callback = None, partial_hyp_callback = None): + def __init__(self, asr_config, partial_hyp_callback = None, hyp_callback = None): super(ASR, self).__init__() self.text = Gtk.TextView() diff --git a/main.py b/main.py index 666b1de6eea42f1f1e8d8f62fc4ab971e006b747..412f0e18cdf6a444db68a5150752f18bf50a9977 100644 --- a/main.py +++ b/main.py @@ -28,13 +28,13 @@ import confirm, asr, actions, xmlview import levenstein, slu, osc class ScriptedASR(Gtk.Window): - def __init__(self, xml_filename, asr_config_file, osc_host, osc_port): + def __init__(self, xml_filename, asr_config_file, osc_host, osc_port, slu_type): super(ScriptedASR, self).__init__() self.connect("destroy", self.quit) self.set_default_size(1024, 768) self.set_border_width(10) - self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s]' % (xml_filename, asr_config_file, osc_host, osc_port)) + self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s slu=%d]' % (xml_filename, asr_config_file, osc_host, osc_port, slu_type)) vbox = Gtk.VBox() self.xmlview = xmlview.XmlView(xml_filename) @@ -44,7 +44,10 @@ class ScriptedASR(Gtk.Window): vbox.pack_start(self.confirmer, False, True, 5) # transcript view - self.asr = asr.ASR(asr_config_file, self.hyp_changed) + if slu_type == 1: + self.asr = asr.ASR(asr_config_file, self.hyp_changed2) + else: + self.asr = asr.ASR(asr_config_file, self.hyp_changed) vbox.pack_start(self.asr, False, True, 5) # slu @@ -59,6 +62,13 @@ class ScriptedASR(Gtk.Window): section_id = int(found.group(1)) self.slu[section_id - 1] = slu.SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', section_fst, prefix % 'clean_tail.fst', library=library) + import alternate_slu + self.slu2 = alternate_slu.SLU(xml_filename) + self.slu2_expected = 0 + self.slu2_performed = set() + self.slu2_last_section = -1 + self.slu2_history = [] + self.add(vbox) self.show_all() @@ -162,6 +172,24 @@ class ScriptedASR(Gtk.Window): model.process(words, self.kept_history, self.slu_finished, False) self.current_section_history[-1] = ' '.join(words) + def hyp_changed2(self, hyp): + section_id = self.xmlview.get_section() + if self.slu2_last_section != section_id: + self.slu2_last_section = section_id + self.slu2_performed = set() + self.slu2_history = [] + self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) + found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) + if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: + self.slu2_performed.add(found.num) + print found.section, found.action, found.history, found.word + found.keyword = self.xmlview.keywords[found.num] + self.xmlview.highlight(found) + self.slu2_expected = found.num + 1 + if self.slu2.last_in_section(found): + new_section = self.xmlview.get_section() + 1 + self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) + def quit(self, window): for slu in self.slu.values(): slu.shutdown() @@ -178,8 +206,8 @@ if __name__ == '__main__': xml_filename = sys.argv[1] if len(sys.argv) > 2: asr_config_file = sys.argv[2] - xml_filename, asr_config_file, osc_host, osc_port = selector.ModelSelector(xml_filename, asr_config_file).run() - if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None: + xml_filename, asr_config_file, osc_host, osc_port, slu_type = selector.ModelSelector(xml_filename, asr_config_file).run() + if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None or slu_type == None: sys.exit(0) - app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port) + app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port, slu_type) Gtk.main() diff --git a/selector.py b/selector.py index c995378c055e89057875183cc2036fb871d5d8c2..8cf50a9ac5701c4bcf972b4190de1453a1c667ec 100644 --- a/selector.py +++ b/selector.py @@ -5,7 +5,7 @@ import config class ModelSelector(Gtk.Dialog): def __init__(self, xml_filename = '', asr_model = ''): super(ModelSelector, self).__init__() - self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234'} + self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234', 'slu': 'Regular'} self.load_options() self.set_title('Configuration') @@ -42,6 +42,18 @@ class ModelSelector(Gtk.Dialog): box.pack_start(model_box, False, False, 5) + slu_box =Gtk.HBox() + slu_box.pack_start(Gtk.Label('SLU model:'), False, False, 5) + slu_chooser = Gtk.ComboBoxText() + slu_chooser.set_entry_text_column(0) + for i, slu in enumerate(['Regular', 'Alternative']): + slu_chooser.append_text(slu) + slu_chooser.set_active(int(self.options['slu'])) + self.slu_chooser = slu_chooser + slu_box.pack_start(slu_chooser, True, True, 5) + + box.pack_start(slu_box, False, False, 5) + osc_box = Gtk.HBox() osc_box.pack_start(Gtk.Label('OSC host:'), False, False, 5) osc_host = Gtk.Entry() @@ -123,7 +135,8 @@ class ModelSelector(Gtk.Dialog): self.options['xml_filename'] = self.xml_entry.get_text() self.options['osc_host'] = self.osc_host.get_text() self.options['osc_port'] = self.osc_port.get_text() + self.options['slu'] = self.slu_chooser.get_active() self.save_options() self.destroy() - return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port'] + return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port'], self.options['slu'] diff --git a/xmlview.py b/xmlview.py index 8fdae41ab03bef6a53a74cba885216810f1e9db9..9d4cd723245e0ef24785288a6b627f1709a40f09 100644 --- a/xmlview.py +++ b/xmlview.py @@ -86,11 +86,13 @@ class Line(Gtk.HBox): self.elements = elements self.get_style_context().add_class('text-line') - def highlight(self, active=True): + def highlight(self, active=True, scrollable=None): if active: self.label.get_style_context().add_class('highlighted') else: self.label.get_style_context().remove_class('highlighted') + if scrollable != None: + animate.scroll_to(scrollable, self) class Keyword(Gtk.Label): def __init__(self, text, action, lang): @@ -102,9 +104,11 @@ class Keyword(Gtk.Label): self.get_style_context().add_class('keyword') self.connect('activate-link', self.link_clicked) - def highlight(self): + def highlight(self, scrollable=None): self.get_style_context().remove_class('keyword') self.get_style_context().add_class('keyword-highlighted') + if scrollable != None: + animate.scroll_to(scrollable, self) def link_clicked(self, widget, uri): actions.perform_action(actions.Action(uri, keyword=widget), False) @@ -171,12 +175,12 @@ class XmlView(Gtk.ScrolledWindow): def highlight(self, action): if hasattr(action, 'keyword'): - action.keyword.highlight() + action.keyword.highlight(self) else: i = self.last_highlighted + 1 while i < len(self.keywords): if self.keywords[i].action == action.text: - self.keywords[i].highlight() + self.keywords[i].highlight(self) self.last_highlighted = i break i += 1