# asr.py -- GTK widget wrapping a Kaldi GStreamer online speech recognizer.
# Author (per commit history): Benoit Favre
import os, sys

from threading import Thread
import gi
gi.require_version('Gst', '1.0')
from gi.repository import GObject, Gst, Gtk, Gdk
# Enable thread support in GObject; presumably kept for older PyGObject
# releases where this call is still required -- confirm before removing.
GObject.threads_init()
# Set up the GDK lock that the ASR signal handlers below acquire with
# Gdk.threads_enter() / Gdk.threads_leave().
Gdk.threads_init()
# Initialise GStreamer without handing it any command-line arguments.
Gst.init(None)

class ASR(Gtk.HBox):
    """GTK widget showing live speech-recognition output next to a Speak/Stop button.

    The widget packs a read-only, auto-scrolling text view and a toggle
    button. A GStreamer pipeline (pulsesrc -> audioconvert -> audioresample ->
    kaldinnet2onlinedecoder -> fakesink) is built on a background thread; the
    button stays insensitive until that thread schedules
    ``_finished_loading_asr`` on the main loop.

    Attributes set by this class:
        hyp: list of hypothesis strings for the current recording; the last
            entry is the utterance currently being decoded.
        hyp_callback / partial_hyp_callback: optional callables that receive
            ``hyp`` after each final / partial result.
    """

    def __init__(self, asr_model_dir, hyp_callback=None, partial_hyp_callback=None):
        """Build the widgets and start loading the recognizer in the background.

        Args:
            asr_model_dir: directory handed to ``init_gst`` as the model
                location.
            hyp_callback: optional callable invoked with ``self.hyp`` after
                every final result.
            partial_hyp_callback: optional callable invoked with ``self.hyp``
                after every partial result.
        """
        super(ASR, self).__init__()

        # Read-only transcript view.
        self.text = Gtk.TextView()
        self.text.set_editable(False)
        self.text.set_cursor_visible(False)
        self.buffer = self.text.get_buffer()
        self.text.set_wrap_mode(Gtk.WrapMode.WORD)

        # Vertical-only scrolling container around the transcript.
        self.scrolled = Gtk.ScrolledWindow()
        self.scrolled.set_policy(Gtk.PolicyType.NEVER, Gtk.PolicyType.ALWAYS)
        self.scrolled.add_with_viewport(self.text)
        self.scrolled.set_size_request(-1, 100)
        self.pack_start(self.scrolled, True, True, 5)

        # The button is enabled by _finished_loading_asr once the pipeline runs.
        self.button = Gtk.Button(label="Speak")
        self.button.set_sensitive(False)
        self.pack_start(self.button, False, False, 5)

        self.button.connect('clicked', self.button_clicked)
        self.text.connect("size-allocate", self.autoscroll)

        self.hyp = []
        self.hyp_callback = hyp_callback
        self.partial_hyp_callback = partial_hyp_callback

        # Pipeline construction happens off the main thread so the UI can
        # appear immediately.
        Thread(target=self.init_gst, args=[asr_model_dir]).start()

    @staticmethod
    def _log_error(message):
        """Write one diagnostic line to standard error."""
        sys.stderr.write(message + "\n")

    def init_gst(self, model='.'):
        """Create, configure and start the GStreamer recognition pipeline.

        Args:
            model: directory containing ``final.mdl``, ``HCLG.fst``,
                ``words.txt`` and the ``conf/`` files referenced below.

        Raises:
            SystemExit: with status 1 when a required GStreamer element cannot
                be created or the model file is missing.

        NOTE(review): ``__init__`` runs this method on a worker thread, where
        ``sys.exit`` only terminates that thread; the process keeps running
        and the button stays on "Loading...". Confirm whether a visible
        failure state is wanted.
        """
        GObject.idle_add(self._started_loading_asr)

        self.pulsesrc = Gst.ElementFactory.make("pulsesrc", "pulsesrc")
        if self.pulsesrc is None:
            self._log_error("Error loading pulsesrc GST plugin. You probably need the gstreamer1.0-pulseaudio package")
            sys.exit(1)
        self.audioconvert = Gst.ElementFactory.make("audioconvert", "audioconvert")
        self.audioresample = Gst.ElementFactory.make("audioresample", "audioresample")
        self.asr = Gst.ElementFactory.make("kaldinnet2onlinedecoder", "asr")
        self.fakesink = Gst.ElementFactory.make("fakesink", "fakesink")

        if self.asr is None:
            self._log_error("Couldn't create the kaldinnet2onlinedecoder element. ")
            if "GST_PLUGIN_PATH" in os.environ:
                self._log_error("Have you compiled the Kaldi GStreamer plugin?")
            else:
                self._log_error("You probably need to set the GST_PLUGIN_PATH environment variable")
                self._log_error("Try running: GST_PLUGIN_PATH=../src %s" % sys.argv[0])
            sys.exit(1)

        def model_path(relative):
            # Keeps the "<model>/<relative>" form the decoder has always received.
            return "%s/%s" % (model, relative)

        if not os.path.isfile(model_path("final.mdl")):
            self._log_error("Models not downloaded? Run prepare-models.sh first!")
            sys.exit(1)

        # Properties are applied in this exact order, matching the original
        # configuration sequence.
        decoder_settings = (
            ("fst", model_path("HCLG.fst")),
            ("model", model_path("final.mdl")),
            ("word-syms", model_path("words.txt")),
            ("acoustic-scale", 0.0416),
            ("feature-type", "mfcc"),
            ("mfcc-config", model_path("conf/mfcc.conf")),
            ("ivector-extraction-config", model_path("conf/ivector_extractor.fixed.conf")),
            ("max-active", 7000),
            ("beam", 11.0),
            ("lattice-beam", 6.0),
            ("do-endpointing", True),
            ("endpoint-silence-phones", "1:2:3:4:5:6:7:8:9:10"),
        )
        for name, value in decoder_settings:
            self.asr.set_property(name, value)

        # Initially silence the decoder; button_clicked un-silences it.
        self.asr.set_property("silent", True)

        self.pipeline = Gst.Pipeline()
        chain = [self.pulsesrc, self.audioconvert, self.audioresample, self.asr, self.fakesink]
        for element in chain:
            self.pipeline.add(element)
        # Link every element to its successor: src -> convert -> resample -> asr -> sink.
        for upstream, downstream in zip(chain, chain[1:]):
            upstream.link(downstream)

        self.asr.connect('partial-result', self._on_partial_result)
        self.asr.connect('final-result', self._on_final_result)
        self.pipeline.set_state(Gst.State.PLAYING)

        GObject.idle_add(self._finished_loading_asr)

    def load_state(self, path='state.txt'):
        """Restore the decoder's ``adaptation-state`` property from a file.

        Best effort: any failure (missing file, decoder not created yet, ...)
        is reported on stderr and otherwise ignored.

        Args:
            path: file to read the state from; defaults to ``state.txt`` in
                the current working directory.
        """
        try:
            with open(path) as fp:
                self.asr.set_property('adaptation-state', fp.read())
        except Exception:
            self._log_error('failed to load asr state')

    def save_state(self, path='state.txt'):
        """Write the decoder's ``adaptation-state`` property to a file.

        Does nothing when the decoder element has not been created. Write
        failures are reported on stderr and otherwise ignored.

        Args:
            path: file to write the state to; defaults to ``state.txt`` in
                the current working directory.
        """
        # self.asr is absent before init_gst runs and None when creation failed.
        if getattr(self, 'asr', None) is None:
            return
        state = self.asr.get_property('adaptation-state')
        try:
            with open(path, 'w') as fp:
                fp.write(state)
        except Exception:
            self._log_error('failed to save asr state')

    def _started_loading_asr(self):
        """Show the "Loading..." state on the (disabled) button."""
        self.button.set_sensitive(False)
        self.button.set_label("Loading...")

    def _finished_loading_asr(self):
        """Restore the "Speak" label and enable the button."""
        self.button.set_label("Speak")
        self.button.set_sensitive(True)

    def _store_hypothesis(self, hyp):
        """Record ``hyp`` as the latest entry of ``self.hyp``, creating it if needed."""
        if not self.hyp:
            self.hyp = ['']
        self.hyp[-1] = hyp

    def _replace_last_line(self, text):
        """Replace the last line of the text buffer with ``text``."""
        self.insert = self.buffer.get_iter_at_line(self.buffer.get_line_count() - 1)
        self.buffer.delete(self.insert, self.buffer.get_end_iter())
        self.buffer.insert(self.insert, text)

    def _on_partial_result(self, asr, hyp):
        """Handle the decoder's ``partial-result`` signal.

        Stores the hypothesis, notifies ``partial_hyp_callback`` and rewrites
        the last buffer line as ``hyp`` followed by ``...``.
        """
        Gdk.threads_enter()
        try:
            self._store_hypothesis(hyp)
            if self.partial_hyp_callback:
                self.partial_hyp_callback(self.hyp)
            self._replace_last_line(hyp + '...')
        finally:
            # Always release the GDK lock, even if the callback raised.
            Gdk.threads_leave()

    def _on_final_result(self, asr, hyp):
        """Handle the decoder's ``final-result`` signal.

        Stores the hypothesis, notifies ``hyp_callback``, commits the line to
        the buffer and opens a fresh ``...`` line plus an empty hypothesis
        slot for the next utterance.
        """
        Gdk.threads_enter()
        try:
            self._store_hypothesis(hyp)
            if self.hyp_callback:
                self.hyp_callback(self.hyp)
            self._replace_last_line(hyp + '\n...')
            self.hyp.append('')
        finally:
            # Always release the GDK lock, even if the callback raised.
            Gdk.threads_leave()

    def autoscroll(self, *args):
        """Scroll the transcript to the bottom (connected to ``size-allocate``)."""
        adj = self.scrolled.get_vadjustment()
        adj.set_value(adj.get_upper() - adj.get_page_size())

    def button_clicked(self, button):
        """Toggle recognition on or off based on the button's current label."""
        if button.get_label() == "Speak":
            # Start a new recording: un-silence the decoder and reset the transcript.
            button.set_label("Stop")
            self.asr.set_property("silent", False)
            self.hyp = []
            self.buffer.set_text('...')
        else:
            button.set_label("Speak")
            self.asr.set_property("silent", True)