Skip to content
Snippets Groups Projects
Commit 56635e22 authored by Benoit Favre's avatar Benoit Favre
Browse files

add support for SLU

parent fd15196f
No related branches found
No related tags found
No related merge requests found
Deps:
- gtk3 for python2 (package is python2-gobject in archlinux, as well as gtk3)
- liblo with python2 bindings (OSC library)
- get and compile https://github.com/alumae/gst-kaldi-nnet2-online (which requires kaldi)
Install:
......@@ -19,12 +20,17 @@ developing with pygtk3: http://lazka.github.io/pgi-docs/, https://python-gtk-3-t
Todo:
- model selection in user interface
- integrate new xml with actions
- account for custom acoustic scale and endpoint silence phones
- refactor:
events = click action or words to resynchronize ?
global action send which deals with
- show a warning (optional)
- sned action through osc
- show an action performed message (message log with timing?)
add global timer which shows elapsed time
change xml view to reflect already performed actions, already recognized text
configuration for osc
make configuration box persistant
/storage/raid1/homedirs/mickael.rouvier/raid2/kaldi_english/exp/nnet2_online/
--acoustic-scale=.04166666666666666666
--endpoint.silence-phones=1:2:3:4:5:6:7:8:9:10:11:12:13:14:15:16:17:18:19:20:21:22:23:24:25:26:27:28:29:30:31:32:33:34:35
- use GtkSourceView to allow editing the xml file directly
SLU.py 0 → 100644
from ctypes import *
_backend = None
class SLU:
#/src.new/rocio_slu -word "$prefix"_dico_word.txt -action "$prefix"_dico_action.txt -fstmodel "$prefix".fst -fstclean "$prefix"_clean_tail.fst
def __init__(self, word_lexicon, action_lexicon, model_fst, cleaner_fst, library='slu/src.new/librocio_slu.so'):
global _backend
if _backend == None:
_backend = cdll.LoadLibrary(library)
# slu_t* init_slu(char* chfileword, char* chfileaction, char* chfilemodel, char* chfileclean);
_backend.init_slu.argtypes = [c_char_p, c_char_p, c_char_p, c_char_p]
_backend.init_slu.restype = c_void_p
# int run_slu(slu_t* slu, char** words, int num_words, int prevn);
_backend.run_slu.argtypes = [c_void_p, POINTER(c_char_p), c_int, c_int]
_backend.run_slu.restype = c_int
# int num_actions(slu_t* slu)
_backend.num_actions.argtypes = [c_void_p]
_backend.num_actions.restype = c_int
# int get_action(slu_t* slu, int index)
_backend.get_action.argtypes = [c_void_p, c_int]
_backend.get_action.restype = c_char_p
# void free_slu(slu_t* slu);
_backend.free_slu.argtypes = [c_void_p]
_backend.free_slu.restype = None
self.slu = _backend.init_slu(word_lexicon, action_lexicon, model_fst, cleaner_fst)
def process(self, words):
c_words = (c_char_p * len(words))(*words)
return _backend.run_slu(self.slu, c_words, len(words), self.num_actions())
def num_actions(self):
return _backend.num_actions(self.slu)
def get_action(self, index):
return _backend.get_action(self.slu, index)
def shutdown(self):
_backend.free_slu(self.slu)
if __name__ == '__main__':
prefix = 'slu/automate/homeostasis_25nov_%s'
slu = SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', prefix % 'section6.fst', prefix % 'clean_tail.fst')
print 'before'
slu.process(open('slu/homeostasis_25nov.asr/sect6.ref').read().strip().split())
print 'after'
slu.shutdown()
from gi.repository import GObject
timer = None
def cancel():
global timer
if timer:
GObject.source_remove(timer)
LINEAR=1
DECELERATE=2
def animate_value(callback, current, target, policy=DECELERATE):
global timer
if current != target:
if policy == DECELERATE:
delta = abs(target - current) / 2
else:
delta = 32
if current > target:
current -= delta
else:
current += delta
if abs(current - target) < 2:
current = target
callback(current)
timer = GObject.timeout_add(1000 / 25, lambda: animate_value(callback, current, target))
def scroll_to(scrollview, widget):
result = widget.translate_coordinates(scrollview.get_child().get_child(), 0, 0)
if result:
adj = scrollview.get_vadjustment()
animate_value(adj.set_value, adj.get_value(), result[1])
......@@ -105,6 +105,7 @@ class ASR(Gtk.HBox):
self.button.set_sensitive(True)
def _on_partial_result(self, asr, hyp):
print 'PARTIAL', self.hyp, hyp
"""Delete any previous selection, insert text and select it."""
Gdk.threads_enter()
if len(self.hyp) == 0:
......@@ -121,16 +122,17 @@ class ASR(Gtk.HBox):
Gdk.threads_leave()
def _on_final_result(self, asr, hyp):
print 'FINAL', self.hyp, hyp
Gdk.threads_enter()
if len(self.hyp) == 0:
self.hyp = ['']
self.hyp[-1] = hyp
#self.hyp[-1] = hyp
if self.hyp_callback:
self.hyp_callback(self.hyp)
self.insert = self.buffer.get_iter_at_line(self.buffer.get_line_count() - 1)
self.buffer.delete(self.insert, self.buffer.get_end_iter())
self.buffer.insert(self.insert, hyp + '\n...')
self.buffer.insert(self.insert, self.hyp[-1] + '\n...')
self.hyp.append('')
Gdk.threads_leave()
......
......@@ -5,12 +5,21 @@
.text-line {
}
.keyword {
}
.text {
}
.highlighted {
color: red;
font: bold;
}
.selected-section-title {
font: bold 18;
color: white;
background: #009900;
}
.section-title {
font: bold 18;
color: white;
......@@ -19,6 +28,12 @@
.section-body {
background: white;
margin: 20px;
}
.selected-section-body {
background: #eeffee;
margin: 20px;
}
.sequence-title {
......
main.py 100755 → 100644
......@@ -2,6 +2,7 @@
import sys
import os
import glob, re
# set to location of libgstkaldionline2.so
os.environ['GST_PLUGIN_PATH'] = './asr/'
......@@ -21,8 +22,8 @@ import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
# import local stuff
import confirm, asr, action, section, xmlview_widgets
import levenstein
import confirm, asr, action, xmlview_widgets
import levenstein, SLU
class ScriptedASR(Gtk.Window):
def __init__(self, xml_filename, asr_config_file):
......@@ -34,40 +35,52 @@ class ScriptedASR(Gtk.Window):
self.set_title('ScriptedASR [%s]' % xml_filename)
vbox = Gtk.VBox()
#self.sections = section.SectionManager()
#vbox.pack_start(self.sections, False, True, 5)
import section
self.sections = section.SectionManager()
vbox.pack_start(self.sections, False, True, 5)
self.xmlview = xmlview_widgets.XmlView(xml_filename)
vbox.pack_start(self.xmlview, True, True, 5)
self.lines = [x for x in self.xmlview.get_line_iterator()]
self.current_line = -1
self.xmlview.set_action_clicked_handler(self.keyword_clicked)
#self.confirmer = confirm.ConfirmationBox()
#vbox.pack_start(self.confirmer, False, True, 5)
self.confirmer = confirm.ConfirmationBox()
vbox.pack_start(self.confirmer, False, True, 5)
#self.actions = action.ActionView()
#vbox.pack_start(self.actions, False, True, 5)
#self.sections.set_confirmer(self.confirmer)
self.sections.set_confirmer(self.confirmer)
#self.actions.set_confirmer(self.confirmer)
self.sections.set_highlight(self.xmlview)
self.sections.set_section(1)
# transcript view
self.asr = asr.ASR(asr_config_file, self.hyp_changed)
self.asr = asr.ASR(asr_config_file, self.hyp_changed, self.hyp_changed)
vbox.pack_start(self.asr, False, True, 5)
# slu
prefix = 'slu/automate/homeostasis_25nov_%s'
self.slu = {}
for section_fst in glob.glob(prefix % 'section*.fst'):
found = re.search('section(\d+)\.fst$', section_fst)
if found:
print section_fst
section_id = int(found.group(1))
self.slu[section_id] = SLU.SLU(prefix % 'dico_word.txt', prefix % 'dico_action.txt', section_fst, prefix % 'clean_tail.fst')
self.add(vbox)
self.show_all()
#self.confirmer.hide()
self.confirmer.hide()
# load css style
style_provider = Gtk.CssProvider()
style_provider.load_from_data(open('data/style.css', 'rb').read())
Gtk.StyleContext.add_provider_for_screen( Gdk.Screen.get_default(), style_provider, Gtk.STYLE_PROVIDER_PRIORITY_APPLICATION)
for line in self.lines:
line.set_click_handler(self.line_clicked)
def line_clicked(self, widget, event):
if self.current_line >= 0:
self.lines[self.current_line].highlight(False)
......@@ -76,29 +89,45 @@ class ScriptedASR(Gtk.Window):
self.current_line = i
self.lines[self.current_line].highlight(True)
def keyword_clicked(self, action):
self.confirmer.confirm('Force action \"%s\"?' % action, 3, None)
def hyp_changed(self, hyp):
#hyp = ' '.join(hyp).replace('[noise]', ' ').split()
words = hyp[-1].strip().replace('_', ' ').split()
if self.current_line >= len(self.lines) - 1:
print "FINISHED"
return
line = self.lines[self.current_line + 1].text.split()
import levenstein
num_errors, num_ref, alignment, score = levenstein.align(line, words)
num_matches = 0
for ref_word, hyp_word in alignment:
if ref_word == hyp_word and ref_word != None:
num_matches += 1
score = float(num_matches) / max(len(line), len(words))
print 'ASR:', hyp[-1], 'REF:', self.lines[self.current_line + 1].text, 'score:', score
levenstein.print_alignment(alignment)
if score >= 0.5:
if self.current_line >= 0:
self.lines[self.current_line].highlight(False)
self.current_line += 1
self.lines[self.current_line].highlight(True)
section_id = self.sections.get_section()
print section_id, self.slu.keys()
if section_id in self.slu:
print 'SLU: ', section_id
slu = self.slu[section_id]
previous_actions = slu.num_actions()
slu.process(words)
for action_id in range(previous_actions, slu.num_actions()):
action = slu.get_action(action_id)
self.confirmer.confirm('Perform action \"%s\"?' % action, 3, None)
#if self.current_line >= len(self.lines) - 1:
# print "FINISHED"
# return
#line = self.lines[self.current_line + 1].text.split()
#import levenstein
#num_errors, num_ref, alignment, score = levenstein.align(line, words)
#num_matches = 0
#for ref_word, hyp_word in alignment:
# if ref_word == hyp_word and ref_word != None:
# num_matches += 1
#score = float(num_matches) / max(len(line), len(words))
#print 'ASR:', hyp[-1], 'REF:', self.lines[self.current_line + 1].text, 'score:', score
#levenstein.print_alignment(alignment)
#if score >= 0.5:
# if self.current_line >= 0:
# self.lines[self.current_line].highlight(False)
# self.current_line += 1
# self.lines[self.current_line].highlight(True)
def quit(self, window):
for slu in self.slu.values():
slu.shutdown()
Gtk.main_quit()
......@@ -106,6 +135,7 @@ if __name__ == '__main__':
import selector
xml_filename = 'data/homeostasis_25nov.xml'
asr_config_file = 'asr/mika-fred-1.cfg'
asr_config_file = 'asr/fisher-benoit-1.cfg'
if len(sys.argv) > 1:
xml_filename = sys.argv[1]
if len(sys.argv) > 2:
......
from gi.repository import Gtk
class DefaultConfirmer:
def confirm(self, *args):
pass
class SectionManager(Gtk.HBox):
def __init__(self):
super(SectionManager, self).__init__()
......@@ -13,16 +17,23 @@ class SectionManager(Gtk.HBox):
self.pack_start(self.label, True, True, 5)
self.pack_start(self.button_next, False, False, 5)
self.set_section(1)
self.confirmer = None
self.confirmer = DefaultConfirmer()
self.button_next.connect('clicked', lambda widget: self.confirmer.confirm('Go to next section?', 3, self.next_section))
self.button_prev.connect('clicked', lambda widget: self.confirmer.confirm('Go to previous section?', 3, self.previous_section))
self.highlight = None
def set_section(self, section):
self.section = section
self.label.set_text('Current section: %d' % self.section)
if self.highlight:
self.highlight.highlight_section(section)
def set_highlight(self, highlight):
self.highlight = highlight
def get_view(self):
return self.hbox
......
......@@ -5,6 +5,7 @@ import config
class ModelSelector(Gtk.Dialog):
def __init__(self, xml_filename = '', asr_model = ''):
super(ModelSelector, self).__init__()
self.set_title('Configuration')
self.add_button("Cancel", Gtk.ResponseType.CANCEL)
self.add_button("OK", Gtk.ResponseType.OK)
box = self.get_content_area()
......@@ -26,13 +27,36 @@ class ModelSelector(Gtk.Dialog):
model_box.pack_start(Gtk.Label('ASR model:'), False, False, 10)
model_chooser = Gtk.ComboBoxText()
model_chooser.set_entry_text_column(0)
target_index = 0
for i, model in enumerate(self.list_models()):
model_chooser.append_text(model)
model_chooser.set_active(0)
if asr_model == self.models[i]:
target_index = i
model_chooser.set_active(target_index)
self.model_chooser = model_chooser
model_box.pack_start(model_chooser, True, True, 10)
box.pack_start(model_box, False, False, 5)
osc_box = Gtk.HBox()
osc_box.pack_start(Gtk.Label('OSC host:'), False, False, 10)
osc_host = Gtk.Entry()
osc_host.set_text('127.0.0.1')
osc_host.set_width_chars(len(osc_host.get_text()))
osc_box.pack_start(osc_host, True, True, 10)
osc_box.pack_start(Gtk.Label('Port:'), False, False, 10)
osc_port = Gtk.Entry()
osc_port.set_text('1234')
osc_port.set_width_chars(len(osc_port.get_text()))
osc_box.pack_start(osc_port, True, True, 10)
box.pack_start(osc_box, False, False, 5)
self.xml_entry.set_activates_default(True)
okButton = self.get_widget_for_response(response_id=Gtk.ResponseType.OK)
okButton.set_can_default(True)
okButton.grab_default()
self.show_all()
def show_filechooser(self, button):
......
#!/usr/bin/env python2
import sys
import os
# set to location of libgstkaldionline2.so
os.environ['GST_PLUGIN_PATH'] = './asr/'
os.environ['GTK_THEME'] = 'light'
# import gtk stuff
from threading import Thread
import gi
gi.require_version('Gst', '1.0')
from gi.repository import GObject, Gst, Gtk, Gdk, GtkSource
GObject.threads_init()
Gdk.threads_init()
Gst.init(None)
# make sure ctrl-c works
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
# import local stuff
import asr
class SourceView(Gtk.Window):
def __init__(self, asr_config_file):
super(SourceView, self).__init__()
self.connect("destroy", self.quit)
self.set_default_size(800,600)
self.set_border_width(10)
self.set_title('SourceView')
vbox = Gtk.VBox()
self.source_buffer = GtkSource.Buffer()
lang_manager = GtkSource.LanguageManager()
self.source_buffer.set_language(lang_manager.get_language('python'))
self.source_buffer.set_text(open(__file__).read())
self.source_view = GtkSource.View.new_with_buffer(self.source_buffer)
self.source_view.set_show_line_numbers(True)
self.source_view.set_tab_width(4)
self.scrolled = Gtk.ScrolledWindow()
self.scrolled.add(self.source_view)
vbox.pack_start(self.scrolled, True, True, 5)
self.asr = asr.ASR(asr_config_file, self.hyp_changed)
vbox.pack_start(self.asr, False, True, 5)
self.add(vbox)
self.show_all()
# load css style
style_provider = Gtk.CssProvider()
style_provider.load_from_data(open('data/style.css', 'rb').read())
Gtk.StyleContext.add_provider_for_screen( Gdk.Screen.get_default(), style_provider, Gtk.STYLE_PROVIDER_PRIORITY_APPLICATION)
def hyp_changed(self, hyp):
pass
def quit(self, window):
Gtk.main_quit()
if __name__ == '__main__':
asr_config_file = 'asr/fisher-benoit-1.cfg'
app = SourceView(asr_config_file)
Gtk.main()
import animate
from gi.repository import GObject, Gtk, Pango, Gdk
from xml.etree import ElementTree as ET
class Section(Gtk.VBox):
def __init__(self, section):
super(Section, self).__init__()
self.name = section.get('id')
self.get_style_context().add_class('section-body')
self.title = Gtk.EventBox()
self.label = Gtk.Label('Section %s' % self.name)
self.label.get_style_context().add_class('section-title')
self.pack_start(self.label, True, True, 5)
self.title.add(self.label)
self.title.connect('button-press-event', self.clicked)
self.handler = None
#cursor = Gdk.Cursor(Gdk.CursorType.HAND1)
#self.label.get_window().set_cursor(cursor)
self.pack_start(self.title, True, True, 5)
self.sequences = []
num = 1
......@@ -18,6 +25,26 @@ class Section(Gtk.VBox):
self.pack_start(self.sequences[-1], True, True, 5)
num += 1
def highlight(self, active=True):
if active:
self.label.get_style_context().remove_class('section-title')
self.label.get_style_context().add_class('selected-section-title')
self.get_style_context().add_class('selected-section-body')
self.get_style_context().remove_class('section-body')
else:
self.label.get_style_context().remove_class('selected-section-title')
self.label.get_style_context().add_class('section-title')
self.get_style_context().remove_class('selected-section-body')
self.get_style_context().add_class('section-body')
def clicked(self, widget, event):
if self.handler:
self.handler(int(self.name) - 1)
def set_handler(self, handler):
self.handler = handler
class Sequence(Gtk.VBox):
def __init__(self, sequence, name):
super(Sequence, self).__init__()
......@@ -28,42 +55,79 @@ class Sequence(Gtk.VBox):
self.pack_start(self.label, True, True, 5)
self.lines = []
text = str(sequence.text)
elements = []
for line in sequence.text.split('\n'):
line = line.strip()
if line != '':
elements.append(Text(line))
if len(elements) > 0:
self.lines.append(Line(elements))
self.pack_start(self.lines[-1], True, True, 5)
elements = []
for node in sequence:
text += node.text
text += node.tail
for line in text.split('\n'):
if node.tag == 'keyword':
text = str(node.text).strip()
if node.get('action').strip() != '':
elements.append(Keyword(text, node.get('action'), node.get('lang')))
else:
elements.append(Text(text))
for line in node.tail.split('\n'):
line = line.strip()
if line != '':
self.lines.append(Line(line))
elements.append(Text(line))
if len(elements) > 0:
self.lines.append(Line(elements))
self.pack_start(self.lines[-1], True, True, 5)
elements = []
if len(elements) > 0:
self.lines.append(Line(elements))
self.pack_start(self.lines[-1], True, True, 5)
class Line(Gtk.EventBox):
def __init__(self, text):
class Line(Gtk.HBox):
def __init__(self, elements):
super(Line, self).__init__()
self.text = text
self.label = Gtk.Label()
self.label.set_text(' ' + text)
self.label.set_halign(Gtk.Align.START)
self.label.get_style_context().add_class('text-line')
self.add(self.label)
def set_click_handler(self, handler):
self.connect('button-press-event', handler)
cursor = Gdk.Cursor(Gdk.CursorType.HAND1)
self.get_window().set_cursor(cursor)
def highlight(self, highlighted=True):
if highlighted:
self.pack_start(Gtk.Label(' '), False, False, 0)
for element in elements:
self.pack_start(element, False, False, 0)
self.elements = elements
self.get_style_context().add_class('text-line')
def set_handler(self, handler):
for element in self.elements:
if hasattr(element, 'set_handler'):
element.set_handler(handler)
def highlight(self, active=True):
if active:
self.label.get_style_context().add_class('highlighted')
else:
self.label.get_style_context().remove_class('highlighted')
class Word:
def __init__(self, text, start, end):
self.text = text
self.start = start
self.end = end
class Keyword(Gtk.Label):
def __init__(self, text, action, lang):
super(Keyword, self).__init__()
self.action = action
self.lang = lang
text = '\n'.join([x.strip() for x in text.split('\n')])
self.set_markup(text + ' [<a href="%s">%s</a>] ' % (action, action))
self.get_style_context().add_class('keyword')
self.connect('activate-link', self.link_clicked)
self.handler = None
def set_handler(self, handler):
self.handler = handler
def link_clicked(self, widget, uri):
if self.handler:
self.handler(uri)
return True
class Text(Gtk.Label):
def __init__(self, text):
super(Text, self).__init__()
text = '\n'.join([x.strip() for x in text.split('\n')])
self.set_text(text + ' ')
self.get_style_context().add_class('text')
class XmlView(Gtk.ScrolledWindow):
def __init__(self, filename):
......@@ -73,7 +137,8 @@ class XmlView(Gtk.ScrolledWindow):
self.set_policy(Gtk.PolicyType.NEVER, Gtk.PolicyType.ALWAYS)
self.add_with_viewport(self.parse_xml(filename))
self.vbox = self.parse_xml(filename)
self.add_with_viewport(self.vbox)
self.last_section = None
......@@ -90,9 +155,21 @@ class XmlView(Gtk.ScrolledWindow):
vbox.pack_start(self.sections[-1], True, True, 5)
return vbox
def set_action_clicked_handler(self, handler):
for line in self.get_line_iterator():
line.set_handler(handler)
def get_line_iterator(self):
for section in self.sections:
for sequence in section.sequences:
for line in sequence.lines:
yield line
def highlight_section(self, section):
if section < 1 or section > len(self.sections):
print "invalid section", section
else:
for current in range(len(self.sections)):
self.sections[current].highlight(current == section - 1)
animate.scroll_to(self, self.sections[section - 1])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment