Skip to content
Snippets Groups Projects
Commit 234a6d88 authored by Benoit Favre
Browse files

add alternate slu

parent ff0185bd
No related branches found
No related tags found
No related merge requests found
slu: 1
xml_filename: data/homeostasis_25nov.xml
osc_host: 127.0.0.1
osc_port: 1234
asr_model: asr/fisher-benoit-1.cfg
asr_model: asr/mika-fred-2.cfg
from xml.etree import ElementTree as ET
from collections import defaultdict
# Size of the history window: SLU pads histories on the left with the integers
# history_length-1 .. 0 and keeps only the last history_length entries.
history_length = 5
# Maps each history entry to the list of Keyword numbers whose history
# contains it. Module-level, so every Keyword created here registers in it.
_index = defaultdict(list)
class Keyword:
    """A keyword of the script together with the entries that precede it.

    Creating a Keyword also appends its number to the module-level
    ``_index`` under every entry of ``history``.
    """

    def __init__(self, num, history, word, action, section):
        self.num = num
        self.word = word
        self.history = history
        self.action = action
        # Stored as the given section id minus one.
        self.section = int(section) - 1
        for context_entry in history:
            _index[context_entry].append(num)
class SLU:
    """Keyword matcher built from the <keyword> elements of an XML script.

    ``self.words`` holds one Keyword per <keyword> element that carries a
    non-blank ``action`` attribute, in document order; a Keyword's ``num``
    is its position in that list.
    """

    def __init__(self, xml_filename):
        """Load every actionable keyword from *xml_filename*."""
        self.words = []
        self.parse_xml(xml_filename)

    def parse_xml(self, filename):
        """Fill ``self.words`` from the <section>/<sequence> elements of *filename*.

        The running history restarts (as integer padding) at each section.
        Missing text, tail text or ``action`` attributes are treated as
        empty instead of raising AttributeError on None.
        """
        root = ET.parse(filename)
        for section in root.findall(".//section"):
            history = list(reversed(range(history_length)))
            for sequence in section.findall('./sequence'):
                # Element.text / Element.tail are None when there is no text.
                history.extend((sequence.text or '').split())
                for node in sequence:
                    if node.tag == 'keyword':
                        history.extend((node.text or '').split())
                        action = node.get('action')
                        if action is not None and action.strip() != '':
                            self.words.append(Keyword(
                                len(self.words),
                                history[-history_length:-1],
                                history[-1],
                                action,
                                section.get("id")))
                    # NOTE(review): the nesting of this statement was lost with
                    # the file's indentation; here the tail text of every child
                    # node is added to the history -- confirm.
                    history.extend((node.tail or '').split())

    def process(self, history, target, expected=-1, section=-1):
        """Return the Keyword that best matches *target* after *history*.

        A candidate's score is the fraction of the history window it shares
        with *history* plus the overlap between its word and *target* (a
        character-set overlap when both are strings, as in the visible
        callers). Candidates with either part equal to zero are ignored.

        history:  entries preceding the target; padded/truncated to
                  ``history_length``.
        target:   the word to match; an empty target yields None.
        expected: preferred keyword number; among equally scored candidates
                  the one whose ``num`` is closest wins (-1 disables this).
        section:  when not -1, only keywords of that section are considered.

        Returns the best Keyword, or None when nothing matches.
        """
        if not target:
            # Avoids dividing by len(target) == 0 below.
            return None
        padding = list(reversed(range(history_length)))
        history = (padding + list(history))[-history_length:]
        candidates = set()
        for entry in history:
            # .get() keeps unseen entries from being inserted into the index.
            candidates.update(_index.get(entry, ()))
        context = set(history)
        best = None
        best_value = 0
        best_distance = None
        # Sorted so the outcome does not depend on set iteration order.
        for num in sorted(candidates):
            word = self.words[num]
            if section != -1 and word.section != section:
                continue
            history_score = len(context & set(word.history)) / float(history_length)
            target_score = len(set(word.word) & set(target)) / float(len(target))
            if history_score == 0 or target_score == 0:
                continue
            value = history_score + target_score
            distance = abs(word.num - expected)
            if value > best_value:
                # A strictly better score also resets the tie-break distance,
                # so later ties are compared against this candidate only.
                best, best_value, best_distance = word, value, distance
            elif expected != -1 and value == best_value and distance < best_distance:
                best, best_distance = word, distance
        return best

    def last_in_section(self, word):
        """Return True when *word* is the last keyword of its section."""
        following = word.num + 1
        if following == len(self.words):
            return True
        return self.words[following].section != word.section
if __name__ == '__main__':
    # Demo: look up one target word given the two words that precede it.
    slu = SLU('data/homeostasis_25nov.xml')
    history = 'open technical'.split()
    word = 'characteristics'
    found = slu.process(history, word)
    if found is None:
        # process() returns None when no keyword matches.
        print('no keyword matched %s' % word)
    else:
        # A single formatted string prints the same text on Python 2 and 3.
        print('%s %s %s %s %s' % (found.word, found.section, found.action, found.history, word))
......@@ -31,4 +31,4 @@ def scroll_to(scrollview, widget):
result = widget.translate_coordinates(scrollview.get_child().get_child(), 0, 0)
if result:
adj = scrollview.get_vadjustment()
animate_value(adj.set_value, adj.get_value(), result[1])
animate_value(adj.set_value, adj.get_value(), result[1] - 150)
......@@ -11,7 +11,7 @@ Gst.init(None)
import osc
class ASR(Gtk.HBox):
def __init__(self, asr_config, hyp_callback = None, partial_hyp_callback = None):
def __init__(self, asr_config, partial_hyp_callback = None, hyp_callback = None):
super(ASR, self).__init__()
self.text = Gtk.TextView()
......
......@@ -28,13 +28,13 @@ import confirm, asr, actions, xmlview
import levenstein, slu, osc
class ScriptedASR(Gtk.Window):
def __init__(self, xml_filename, asr_config_file, osc_host, osc_port):
def __init__(self, xml_filename, asr_config_file, osc_host, osc_port, slu_type):
super(ScriptedASR, self).__init__()
self.connect("destroy", self.quit)
self.set_default_size(1024, 768)
self.set_border_width(10)
self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s]' % (xml_filename, asr_config_file, osc_host, osc_port))
self.set_title('ASR Transcript [xml=%s asr=%s osc=%s:%s slu=%d]' % (xml_filename, asr_config_file, osc_host, osc_port, slu_type))
vbox = Gtk.VBox()
self.xmlview = xmlview.XmlView(xml_filename)
......@@ -44,7 +44,10 @@ class ScriptedASR(Gtk.Window):
vbox.pack_start(self.confirmer, False, True, 5)
# transcript view
self.asr = asr.ASR(asr_config_file, self.hyp_changed)
if slu_type == 1:
self.asr = asr.ASR(asr_config_file, self.hyp_changed2)
else:
self.asr = asr.ASR(asr_config_file, self.hyp_changed)
vbox.pack_start(self.asr, False, True, 5)
# slu
......@@ -59,6 +62,13 @@ class ScriptedASR(Gtk.Window):
section_id = int(found.group(1))
self.slu[section_id - 1] = slu.SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', section_fst, prefix % 'clean_tail.fst', library=library)
import alternate_slu
self.slu2 = alternate_slu.SLU(xml_filename)
self.slu2_expected = 0
self.slu2_performed = set()
self.slu2_last_section = -1
self.slu2_history = []
self.add(vbox)
self.show_all()
......@@ -162,6 +172,24 @@ class ScriptedASR(Gtk.Window):
model.process(words, self.kept_history, self.slu_finished, False)
self.current_section_history[-1] = ' '.join(words)
def hyp_changed2(self, hyp):
    """Match the newest hypothesis words against the alternate SLU.

    The words of ``hyp[-1]`` (underscores turned into spaces) are appended
    to the per-section history, and the last word is looked up with the
    earlier ones as context. A keyword not yet performed in this section
    whose word equals that last word is highlighted; if it is the last
    keyword of its section, the user is asked to confirm moving on.
    """
    section_id = self.xmlview.get_section()
    if self.slu2_last_section != section_id:
        # Entering another section: forget what was heard and performed.
        self.slu2_last_section = section_id
        self.slu2_performed = set()
        self.slu2_history = []
    if hyp:
        self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split())
    if not self.slu2_history:
        # Nothing heard yet in this section; there is no word to look up.
        return
    target = self.slu2_history[-1]
    found = self.slu2.process(self.slu2_history[:-1], target, self.slu2_expected, section_id)
    if found and found.num not in self.slu2_performed and found.word == target:
        self.slu2_performed.add(found.num)
        # A single formatted string prints the same text on Python 2 and 3.
        print('%s %s %s %s' % (found.section, found.action, found.history, found.word))
        found.keyword = self.xmlview.keywords[found.num]
        self.xmlview.highlight(found)
        self.slu2_expected = found.num + 1
        if self.slu2.last_in_section(found):
            new_section = self.xmlview.get_section() + 1
            self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section))
def quit(self, window):
for slu in self.slu.values():
slu.shutdown()
......@@ -178,8 +206,8 @@ if __name__ == '__main__':
xml_filename = sys.argv[1]
if len(sys.argv) > 2:
asr_config_file = sys.argv[2]
xml_filename, asr_config_file, osc_host, osc_port = selector.ModelSelector(xml_filename, asr_config_file).run()
if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None:
xml_filename, asr_config_file, osc_host, osc_port, slu_type = selector.ModelSelector(xml_filename, asr_config_file).run()
if xml_filename == None or asr_config_file == None or osc_host == None or osc_port == None or slu_type == None:
sys.exit(0)
app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port)
app = ScriptedASR(xml_filename, asr_config_file, osc_host, osc_port, slu_type)
Gtk.main()
......@@ -5,7 +5,7 @@ import config
class ModelSelector(Gtk.Dialog):
def __init__(self, xml_filename = '', asr_model = ''):
super(ModelSelector, self).__init__()
self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234'}
self.options = {'xml_filename': xml_filename, 'asr_model': asr_model, 'osc_host': '127.0.0.1', 'osc_port': '1234', 'slu': 'Regular'}
self.load_options()
self.set_title('Configuration')
......@@ -42,6 +42,18 @@ class ModelSelector(Gtk.Dialog):
box.pack_start(model_box, False, False, 5)
slu_box =Gtk.HBox()
slu_box.pack_start(Gtk.Label('SLU model:'), False, False, 5)
slu_chooser = Gtk.ComboBoxText()
slu_chooser.set_entry_text_column(0)
for i, slu in enumerate(['Regular', 'Alternative']):
slu_chooser.append_text(slu)
slu_chooser.set_active(int(self.options['slu']))
self.slu_chooser = slu_chooser
slu_box.pack_start(slu_chooser, True, True, 5)
box.pack_start(slu_box, False, False, 5)
osc_box = Gtk.HBox()
osc_box.pack_start(Gtk.Label('OSC host:'), False, False, 5)
osc_host = Gtk.Entry()
......@@ -123,7 +135,8 @@ class ModelSelector(Gtk.Dialog):
self.options['xml_filename'] = self.xml_entry.get_text()
self.options['osc_host'] = self.osc_host.get_text()
self.options['osc_port'] = self.osc_port.get_text()
self.options['slu'] = self.slu_chooser.get_active()
self.save_options()
self.destroy()
return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port']
return self.options['xml_filename'], self.options['asr_model'], self.options['osc_host'], self.options['osc_port'], self.options['slu']
......@@ -86,11 +86,13 @@ class Line(Gtk.HBox):
self.elements = elements
self.get_style_context().add_class('text-line')
def highlight(self, active=True):
def highlight(self, active=True, scrollable=None):
if active:
self.label.get_style_context().add_class('highlighted')
else:
self.label.get_style_context().remove_class('highlighted')
if scrollable != None:
animate.scroll_to(scrollable, self)
class Keyword(Gtk.Label):
def __init__(self, text, action, lang):
......@@ -102,9 +104,11 @@ class Keyword(Gtk.Label):
self.get_style_context().add_class('keyword')
self.connect('activate-link', self.link_clicked)
def highlight(self):
def highlight(self, scrollable=None):
self.get_style_context().remove_class('keyword')
self.get_style_context().add_class('keyword-highlighted')
if scrollable != None:
animate.scroll_to(scrollable, self)
def link_clicked(self, widget, uri):
actions.perform_action(actions.Action(uri, keyword=widget), False)
......@@ -171,12 +175,12 @@ class XmlView(Gtk.ScrolledWindow):
def highlight(self, action):
if hasattr(action, 'keyword'):
action.keyword.highlight()
action.keyword.highlight(self)
else:
i = self.last_highlighted + 1
while i < len(self.keywords):
if self.keywords[i].action == action.text:
self.keywords[i].highlight()
self.keywords[i].highlight(self)
self.last_highlighted = i
break
i += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment