Skip to content
Snippets Groups Projects
Commit 234a6d88 authored by Benoit Favre
Browse files

add alternate slu

parent ff0185bd
No related branches found
No related tags found
No related merge requests found
slu: 1
xml_filename: data/homeostasis_25nov.xml xml_filename: data/homeostasis_25nov.xml
osc_host: 127.0.0.1 osc_host: 127.0.0.1
osc_port: 1234 osc_port: 1234
asr_model: asr/fisher-benoit-1.cfg asr_model: asr/mika-fred-2.cfg
from xml.etree import ElementTree as ET
from collections import defaultdict
# Size of the context window compared in SLU.process(). Each keyword stores
# the history_length - 1 tokens that precede it in the script.
history_length = 5

# Legacy module-level inverted index (context token -> keyword numbers).
# It is only written to when a Keyword is built without an explicit index;
# SLU keeps its own per-instance index so several instances never mix entries.
_index = defaultdict(list)


def _padding():
    """Return the integer tokens [history_length-1, ..., 0] used as left padding."""
    return list(reversed(range(history_length)))


def _tokens(text):
    """Split an XML text/tail value into words; None (absent text) yields []."""
    if text is None:
        return []
    return text.strip().split()


class Keyword:
    """A keyword of the script together with the context that precedes it.

    Attributes:
        num: position of the keyword in the owning keyword list.
        history: tokens preceding the keyword (may contain padding integers).
        word: the last token of the keyword text.
        action: value of the keyword's ``action`` attribute.
        section: zero-based section number (``int(section) - 1``).
    """

    def __init__(self, num, history, word, action, section, index=None):
        """Store the keyword and register ``num`` under each history token.

        ``index`` is the mapping to register into; when omitted the
        module-level ``_index`` is used, as before.
        """
        self.history = history
        self.word = word
        self.num = num
        self.action = action
        self.section = int(section) - 1
        if index is None:
            index = _index
        for context_word in history:
            index[context_word].append(num)


class SLU:
    """Keyword spotter that matches a recognised word and its recent context
    against the keywords declared in a script XML file."""

    def __init__(self, xml_filename):
        """Parse ``xml_filename`` (a path or file object accepted by ET.parse)."""
        self.words = []
        # Per-instance inverted index: context token -> keyword numbers.
        self._index = defaultdict(list)
        self.parse_xml(xml_filename)

    def parse_xml(self, filename):
        """Collect every ``keyword`` element that has a non-empty ``action``.

        The running word history restarts (with integer padding) at each
        ``section``. Missing text, tail or ``action`` values are treated as
        empty instead of raising AttributeError.
        """
        root = ET.parse(filename)
        for section in root.findall(".//section"):
            history = _padding()
            for sequence in section.findall('./sequence'):
                history.extend(_tokens(sequence.text))
                for node in sequence:
                    if node.tag == 'keyword':
                        history.extend(_tokens(node.text))
                        action = node.get('action')
                        if action is not None and action.strip() != '':
                            self.words.append(Keyword(
                                len(self.words),
                                history[-history_length: -1],
                                history[-1],
                                action,
                                section.get("id"),
                                self._index))
                    # NOTE(review): the tail text is appended for every child
                    # node, keyword or not - confirm against the original
                    # indentation, which was lost in this copy.
                    history.extend(_tokens(node.tail))

    def process(self, history, target, expected=-1, section=-1):
        """Return the best matching Keyword for ``target``, or None.

        Args:
            history: words recognised before ``target`` (most recent last).
            target: the word to match against keyword words.
            expected: keyword number used to break score ties (the candidate
                closest to it wins); -1 disables tie-breaking.
            section: zero-based section to restrict the search to; -1 searches
                every section.

        A candidate must share at least one context token and at least one
        character with ``target``. Its score is the fraction of shared context
        tokens plus the fraction of shared characters.
        """
        if not target:
            # Nothing to match; also avoids dividing by len(target) == 0.
            return None
        context = (_padding() + list(history))[-history_length:]
        candidates = set()
        for token in context:
            # .get() so that unknown tokens do not grow the index.
            candidates.update(self._index.get(token, ()))
        context = set(context)
        target_chars = set(target)
        best = None
        best_value = 0
        best_distance = None
        # Sorted so that ties are resolved the same way on every run.
        for num in sorted(candidates):
            word = self.words[num]
            if section != -1 and word.section != section:
                continue
            history_score = len(context & set(word.history)) / float(history_length)
            target_score = len(set(word.word) & target_chars) / float(len(target))
            if history_score == 0 or target_score == 0:
                continue
            value = history_score + target_score
            distance = abs(word.num - expected)
            if value > best_value:
                # A strictly better score restarts the tie-break distance, so
                # a weaker candidate's distance is never compared against.
                best, best_value, best_distance = word, value, distance
            elif expected != -1 and value == best_value and distance < best_distance:
                best, best_distance = word, distance
        return best

    def last_in_section(self, word):
        """Return True when ``word`` is the final keyword of its section."""
        next_num = word.num + 1
        if next_num >= len(self.words):
            return True
        return self.words[next_num].section != word.section
if __name__ == '__main__':
    # Manual smoke test: look up one word with a short context in the sample
    # script and print the matched keyword's fields.
    slu = SLU('data/homeostasis_25nov.xml')
    history = 'open technical'.split()
    word = 'characteristics'
    found = slu.process(history, word)
    if found is None:
        # process() returns None when nothing matches; report it instead of
        # failing with AttributeError.
        print('no keyword matched %r' % word)
    else:
        # Single-argument print works both as a Python 2 statement and as a
        # Python 3 function and produces the same space-separated output.
        print(' '.join(str(field) for field in (found.word, found.section, found.action, found.history, word)))
...@@ -31,4 +31,4 @@ def scroll_to(scrollview, widget): ...@@ -31,4 +31,4 @@ def scroll_to(scrollview, widget):
result = widget.translate_coordinates(scrollview.get_child().get_child(), 0, 0) result = widget.translate_coordinates(scrollview.get_child().get_child(), 0, 0)
if result: if result:
adj = scrollview.get_vadjustment() adj = scrollview.get_vadjustment()
animate_value(adj.set_value, adj.get_value(), result[1]) animate_value(adj.set_value, adj.get_value(), result[1] - 150)
...@@ -11,7 +11,7 @@ Gst.init(None) ...@@ -11,7 +11,7 @@ Gst.init(None)
import osc import osc
class ASR(Gtk.HBox): class ASR(Gtk.HBox):
def __init__(self, asr_config, hyp_callback = None, partial_hyp_callback = None): def __init__(self, asr_config, partial_hyp_callback = None, hyp_callback = None):
super(ASR, self).__init__() super(ASR, self).__init__()
self.text = Gtk.TextView() self.text = Gtk.TextView()
......
...@@ -28,13 +28,13 @@ import confirm, asr, actions, xmlview ...@@ -28,13 +28,13 @@ import confirm, asr, actions, xmlview
import levenstein, slu, osc import levenstein, slu, osc
class ScriptedASR(Gtk.Window): class ScriptedASR(Gtk.Window):
def __init__(self, xml_filename, asr_config_file, osc_host, osc_port): def __init__(self, xml_filename, asr_config_file, osc_host, osc_port, slu_type):
super(ScriptedASR, self).__init__() super(ScriptedASR, self).__init__()
self.connect("destroy", self.quit) self.connect("destroy", self.quit)
self.set_default_size(1024, 768) self.set_default_size(1024, 768)
self.set_border_width(10) self.set_border_width(10)
self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s]' % (xml_filename, asr_config_file, osc_host, osc_port)) self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s slu=%d]' % (xml_filename, asr_config_file, osc_host, osc_port, slu_type))
vbox = Gtk.VBox() vbox = Gtk.VBox()
self.xmlview = xmlview.XmlView(xml_filename) self.xmlview = xmlview.XmlView(xml_filename)
...@@ -44,7 +44,10 @@ class ScriptedASR(Gtk.Window): ...@@ -44,7 +44,10 @@ class ScriptedASR(Gtk.Window):
vbox.pack_start(self.confirmer, False, True, 5) vbox.pack_start(self.confirmer, False, True, 5)
# transcript view # transcript view
self.asr = asr.ASR(asr_config_file, self.hyp_changed) if slu_type == 1:
self.asr = asr.ASR(asr_config_file, self.hyp_changed2)
else:
self.asr = asr.ASR(asr_config_file, self.hyp_changed)
vbox.pack_start(self.asr, False, True, 5) vbox.pack_start(self.asr, False, True, 5)
# slu # slu
...@@ -59,6 +62,13 @@ class ScriptedASR(Gtk.Window): ...@@ -59,6 +62,13 @@ class ScriptedASR(Gtk.Window):
section_id = int(found.group(1)) section_id = int(found.group(1))
self.slu[section_id - 1] = slu.SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', section_fst, prefix % 'clean_tail.fst', library=library) self.slu[section_id - 1] = slu.SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', section_fst, prefix % 'clean_tail.fst', library=library)
import alternate_slu
self.slu2 = alternate_slu.SLU(xml_filename)
self.slu2_expected = 0
self.slu2_performed = set()
self.slu2_last_section = -1
self.slu2_history = []
self.add(vbox) self.add(vbox)
self.show_all() self.show_all()
...@@ -162,6 +172,24 @@ class ScriptedASR(Gtk.Window): ...@@ -162,6 +172,24 @@ class ScriptedASR(Gtk.Window):
model.process(words, self.kept_history, self.slu_finished, False) model.process(words, self.kept_history, self.slu_finished, False)
self.current_section_history[-1] = ' '.join(words) self.current_section_history[-1] = ' '.join(words)
def hyp_changed2(self, hyp):
    """Feed the latest hypothesis to the alternate SLU and act on a match.

    ``hyp`` is a sequence of hypothesis strings; only the last one is used.
    Its words (underscores are treated as spaces) are appended to the
    per-section word history, and the newest word is matched with the
    preceding words as context. A keyword is highlighted at most once per
    section, and only when its word equals the newest word exactly. When the
    match is the last keyword of its section, the user is asked whether to
    move to the next section.
    """
    section_id = self.xmlview.get_section()
    if self.slu2_last_section != section_id:
        # Entering another section: forget what was heard and done before.
        self.slu2_last_section = section_id
        self.slu2_performed = set()
        self.slu2_history = []
    new_words = hyp[-1].strip().replace('_', ' ').split() if hyp else []
    if not new_words:
        # Nothing new was recognised; indexing an empty history would raise
        # IndexError and re-running the previous word cannot change anything.
        return
    self.slu2_history.extend(new_words)
    found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id)
    if found is not None and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]:
        self.slu2_performed.add(found.num)
        # Single-argument print: same output under Python 2 and Python 3.
        print(' '.join(str(field) for field in (found.section, found.action, found.history, found.word)))
        found.keyword = self.xmlview.keywords[found.num]
        self.xmlview.highlight(found)
        # The keyword right after this one is the most likely next match.
        self.slu2_expected = found.num + 1
        if self.slu2.last_in_section(found):
            new_section = self.xmlview.get_section() + 1
            self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section))
def quit(self, window): def quit(self, window):
for slu in self.slu.values(): for slu in self.slu.values():
slu.shutdown() slu.shutdown()
...@@ -178,8 +206,8 @@ if __name__ == '__main__': ...@@ -178,8 +206,8 @@ if __name__ == '__main__':
xml_filename = sys.argv[1] xml_filename = sys.argv[1]
if len(sys.argv) > 2: if len(sys.argv) > 2:
asr_config_file = sys.argv[2] asr_config_file = sys.argv[2]
xml_filename, asr_config_file, osc_host, osc_port = selector.ModelSelector(xml_filename, asr_config_file).run() xml_filename, asr_config_file, osc_host, osc_port, slu_type = selector.ModelSelector(xml_filename, asr_config_file).run()
if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None: if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None or slu_type == None:
sys.exit(0) sys.exit(0)
app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port) app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port, slu_type)
Gtk.main() Gtk.main()
...@@ -5,7 +5,7 @@ import config ...@@ -5,7 +5,7 @@ import config
class ModelSelector(Gtk.Dialog): class ModelSelector(Gtk.Dialog):
def __init__(self, xml_filename = '', asr_model = ''): def __init__(self, xml_filename = '', asr_model = ''):
super(ModelSelector, self).__init__() super(ModelSelector, self).__init__()
self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234'} self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234', 'slu': 'Regular'}
self.load_options() self.load_options()
self.set_title('Configuration') self.set_title('Configuration')
...@@ -42,6 +42,18 @@ class ModelSelector(Gtk.Dialog): ...@@ -42,6 +42,18 @@ class ModelSelector(Gtk.Dialog):
box.pack_start(model_box, False, False, 5) box.pack_start(model_box, False, False, 5)
slu_box =Gtk.HBox()
slu_box.pack_start(Gtk.Label('SLU model:'), False, False, 5)
slu_chooser = Gtk.ComboBoxText()
slu_chooser.set_entry_text_column(0)
for i, slu in enumerate(['Regular', 'Alternative']):
slu_chooser.append_text(slu)
slu_chooser.set_active(int(self.options['slu']))
self.slu_chooser = slu_chooser
slu_box.pack_start(slu_chooser, True, True, 5)
box.pack_start(slu_box, False, False, 5)
osc_box = Gtk.HBox() osc_box = Gtk.HBox()
osc_box.pack_start(Gtk.Label('OSC host:'), False, False, 5) osc_box.pack_start(Gtk.Label('OSC host:'), False, False, 5)
osc_host = Gtk.Entry() osc_host = Gtk.Entry()
...@@ -123,7 +135,8 @@ class ModelSelector(Gtk.Dialog): ...@@ -123,7 +135,8 @@ class ModelSelector(Gtk.Dialog):
self.options['xml_filename'] = self.xml_entry.get_text() self.options['xml_filename'] = self.xml_entry.get_text()
self.options['osc_host'] = self.osc_host.get_text() self.options['osc_host'] = self.osc_host.get_text()
self.options['osc_port'] = self.osc_port.get_text() self.options['osc_port'] = self.osc_port.get_text()
self.options['slu'] = self.slu_chooser.get_active()
self.save_options() self.save_options()
self.destroy() self.destroy()
return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port'] return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port'], self.options['slu']
...@@ -86,11 +86,13 @@ class Line(Gtk.HBox): ...@@ -86,11 +86,13 @@ class Line(Gtk.HBox):
self.elements = elements self.elements = elements
self.get_style_context().add_class('text-line') self.get_style_context().add_class('text-line')
def highlight(self, active=True): def highlight(self, active=True, scrollable=None):
if active: if active:
self.label.get_style_context().add_class('highlighted') self.label.get_style_context().add_class('highlighted')
else: else:
self.label.get_style_context().remove_class('highlighted') self.label.get_style_context().remove_class('highlighted')
if scrollable != None:
animate.scroll_to(scrollable, self)
class Keyword(Gtk.Label): class Keyword(Gtk.Label):
def __init__(self, text, action, lang): def __init__(self, text, action, lang):
...@@ -102,9 +104,11 @@ class Keyword(Gtk.Label): ...@@ -102,9 +104,11 @@ class Keyword(Gtk.Label):
self.get_style_context().add_class('keyword') self.get_style_context().add_class('keyword')
self.connect('activate-link', self.link_clicked) self.connect('activate-link', self.link_clicked)
def highlight(self): def highlight(self, scrollable=None):
self.get_style_context().remove_class('keyword') self.get_style_context().remove_class('keyword')
self.get_style_context().add_class('keyword-highlighted') self.get_style_context().add_class('keyword-highlighted')
if scrollable != None:
animate.scroll_to(scrollable, self)
def link_clicked(self, widget, uri): def link_clicked(self, widget, uri):
actions.perform_action(actions.Action(uri, keyword=widget), False) actions.perform_action(actions.Action(uri, keyword=widget), False)
...@@ -171,12 +175,12 @@ class XmlView(Gtk.ScrolledWindow): ...@@ -171,12 +175,12 @@ class XmlView(Gtk.ScrolledWindow):
def highlight(self, action): def highlight(self, action):
if hasattr(action, 'keyword'): if hasattr(action, 'keyword'):
action.keyword.highlight() action.keyword.highlight(self)
else: else:
i = self.last_highlighted + 1 i = self.last_highlighted + 1
while i < len(self.keywords): while i < len(self.keywords):
if self.keywords[i].action == action.text: if self.keywords[i].action == action.text:
self.keywords[i].highlight() self.keywords[i].highlight(self)
self.last_highlighted = i self.last_highlighted = i
break break
i += 1 i += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment