Skip to content
Snippets Groups Projects
asr.py 6.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • Benoit Favre's avatar
    Benoit Favre committed
    import os, sys
    
    from threading import Thread
    import gi
    gi.require_version('Gst', '1.0')
    from gi.repository import GObject, Gst, Gtk, Gdk
    # Module-level initialisation, executed once at import time and therefore
    # before any ASR widget can be constructed: set up threading support in
    # GObject and GDK, then initialise GStreamer without passing it any
    # command-line arguments.
    # NOTE(review): GObject.threads_init() is presumably unnecessary on recent
    # PyGObject releases -- confirm before removing it.
    GObject.threads_init()
    Gdk.threads_init()
    Gst.init(None)
    
    class ASR(Gtk.HBox):
    
        def __init__(self, hyp_callback=None, partial_hyp_callback=None):
            """Build the transcript view and the "Speak" button, then start loading the recognizer.

            hyp_callback, if given, is called with the list of hypotheses
            (self.hyp) each time the decoder emits a final result.
            partial_hyp_callback, if given, is called the same way for every
            partial result.
            """
            super(ASR, self).__init__()

            # Read-only, word-wrapped text area that displays the transcript.
            self.text = Gtk.TextView()
            self.text.set_editable(False)
            self.text.set_cursor_visible(False)
            self.buffer = self.text.get_buffer()
            self.text.set_wrap_mode(Gtk.WrapMode.WORD)

            # Vertical-only scrolling container around the text area.
            self.scrolled = Gtk.ScrolledWindow()
            self.scrolled.set_policy(Gtk.PolicyType.NEVER, Gtk.PolicyType.ALWAYS)
            self.scrolled.add_with_viewport(self.text)
            self.scrolled.set_size_request(-1, 100)
            self.pack_start(self.scrolled, True, True, 5)

            # Pass the label by keyword: positional arguments to GObject
            # constructors are deprecated in PyGObject.
            # The button stays disabled until init_gst() has finished.
            self.button = Gtk.Button(label="Speak")
            self.button.set_sensitive(False)
            self.pack_start(self.button, False, False, 5)

            self.button.connect('clicked', self.button_clicked)
            # Keep the newest text visible whenever the text view is re-allocated.
            self.text.connect("size-allocate", self.autoscroll)

            self.hyp = []
            self.hyp_callback = hyp_callback
            self.partial_hyp_callback = partial_hyp_callback

            # Build the GStreamer pipeline on a worker thread; init_gst()
            # re-enables the button through an idle callback when it is done.
            Thread(target=self.init_gst).start()
    
        def init_gst(self):
            """Create the speech-recognition GStreamer pipeline and start it, silenced.

            The pipeline is pulsesrc -> audioconvert -> audioresample ->
            kaldinnet2onlinedecoder -> fakesink. The button is switched to
            "Loading..." at the start and back to "Speak" at the end, both
            through idle callbacks.

            Raises SystemExit (status 1) after printing a hint on stderr when
            the pulsesrc or decoder element cannot be created, or when the
            model file is missing.
            """
            def abort(*messages):
                # Print each message on its own line and raise SystemExit(1).
                # sys.stderr.write works on both Python 2 and Python 3, unlike
                # the `print >> sys.stderr` statement.
                # NOTE: when this method runs on the worker thread started by
                # __init__, SystemExit only ends that thread, not the process.
                for message in messages:
                    sys.stderr.write(message + "\n")
                sys.exit(1)

            GObject.idle_add(self._started_loading_asr)

            self.pulsesrc = Gst.ElementFactory.make("pulsesrc", "pulsesrc")
            if self.pulsesrc is None:
                abort("Error loading pulsesrc GST plugin. You probably need the gstreamer1.0-pulseaudio package")
            self.audioconvert = Gst.ElementFactory.make("audioconvert", "audioconvert")
            self.audioresample = Gst.ElementFactory.make("audioresample", "audioresample")
            self.asr = Gst.ElementFactory.make("kaldinnet2onlinedecoder", "asr")
            self.fakesink = Gst.ElementFactory.make("fakesink", "fakesink")

            if not self.asr:
                hints = ["Couldn't create the kaldinnet2onlinedecoder element. "]
                if "GST_PLUGIN_PATH" in os.environ:
                    hints.append("Have you compiled the Kaldi GStreamer plugin?")
                else:
                    hints.append("You probably need to set the GST_PLUGIN_PATH environment variable")
                    hints.append("Try running: GST_PLUGIN_PATH=../src %s" % sys.argv[0])
                abort(*hints)

            # All model paths are relative to the current working directory.
            model_file = "asr/final.mdl"
            if not os.path.isfile(model_file):
                abort("Models not downloaded? Run prepare-models.sh first!")

            # Decoder configuration.
            self.asr.set_property("fst", "asr/model2/HCLG.fst")
            self.asr.set_property("model", model_file)
            self.asr.set_property("word-syms", "asr/model2/words.txt")
            self.asr.set_property("feature-type", "mfcc")
            self.asr.set_property("mfcc-config", "asr/conf/mfcc.conf")
            self.asr.set_property("ivector-extraction-config", "asr/conf/ivector_extractor.fixed.conf")
            self.asr.set_property("max-active", 7000)
            self.asr.set_property("beam", 11.0)
            self.asr.set_property("lattice-beam", 6.0)
            self.asr.set_property("do-endpointing", True)
            self.asr.set_property("endpoint-silence-phones", "1:2:3:4:5:6:7:8:9:10")

            # Initially silence the decoder; button_clicked() un-silences it.
            self.asr.set_property("silent", True)

            # Add every element to the pipeline, then link them in order.
            self.pipeline = Gst.Pipeline()
            chain = [self.pulsesrc, self.audioconvert, self.audioresample, self.asr, self.fakesink]
            for element in chain:
                self.pipeline.add(element)
            self.pulsesrc.link(self.audioconvert)
            self.audioconvert.link(self.audioresample)
            self.audioresample.link(self.asr)
            self.asr.link(self.fakesink)

            self.asr.connect('partial-result', self._on_partial_result)
            self.asr.connect('final-result', self._on_final_result)
            self.pipeline.set_state(Gst.State.PLAYING)

            GObject.idle_add(self._finished_loading_asr)
    
        def load_state(self):
            """Restore the decoder's 'adaptation-state' property from state.txt.

            Best effort: any ordinary failure (for example the file is missing
            or the decoder has not been created yet) is reported on stderr and
            otherwise ignored.
            """
            try:
                with open('state.txt') as fp:
                    self.asr.set_property('adaptation-state', fp.read())
            except Exception:
                # `except Exception` rather than a bare `except`, so that
                # KeyboardInterrupt and SystemExit are not swallowed.
                sys.stderr.write('failed to load asr state\n')
    
        def save_state(self):
            """Write the decoder's 'adaptation-state' property to state.txt.

            Does nothing when the decoder does not exist (attribute missing, or
            set to None because element creation failed). Best effort otherwise:
            ordinary failures are reported on stderr and ignored.
            """
            asr = getattr(self, 'asr', None)
            if asr is None:
                return
            try:
                state = asr.get_property('adaptation-state')
                if state is None:
                    # Refuse before opening the file, so an existing state.txt
                    # is not truncated by a write that is bound to fail.
                    raise ValueError('no adaptation state available')
                with open('state.txt', 'w') as fp:
                    fp.write(state)
            except Exception:
                # `except Exception` rather than a bare `except`, so that
                # KeyboardInterrupt and SystemExit are not swallowed.
                sys.stderr.write('failed to save asr state\n')
    
        def _started_loading_asr(self):
            """Disable the button and label it "Loading..." (scheduled by init_gst as an idle callback)."""
            loading_button = self.button
            loading_button.set_sensitive(False)
            loading_button.set_label("Loading...")
    
        def _finished_loading_asr(self):
            """Relabel the button "Speak" and enable it (scheduled by init_gst as an idle callback)."""
            ready_button = self.button
            ready_button.set_label("Speak")
            ready_button.set_sensitive(True)
    
        def _on_partial_result(self, asr, hyp):
            """Handle the decoder's 'partial-result' signal.

            Stores hyp as the current (last) hypothesis, passes the hypothesis
            list to partial_hyp_callback if one was given, and replaces the
            last line of the transcript with hyp followed by '...'.
            """
            Gdk.threads_enter()
            try:
                if not self.hyp:
                    self.hyp = ['']
                self.hyp[-1] = hyp

                if self.partial_hyp_callback:
                    self.partial_hyp_callback(self.hyp)

                # Overwrite everything from the start of the last line to the
                # end of the buffer with the new partial text.
                self.insert = self.buffer.get_iter_at_line(self.buffer.get_line_count() - 1)
                self.buffer.delete(self.insert, self.buffer.get_end_iter())
                self.buffer.insert(self.insert, hyp + '...')
            finally:
                # Always release the GDK lock, even if the callback or a
                # buffer operation raised; otherwise it would stay held.
                Gdk.threads_leave()
                    
        def _on_final_result(self, asr, hyp):
            """Handle the decoder's 'final-result' signal.

            Stores hyp as the current (last) hypothesis, passes the hypothesis
            list to hyp_callback if one was given, replaces the last line of
            the transcript with hyp, starts a new '...' line, and appends an
            empty hypothesis for the next utterance.
            """
            Gdk.threads_enter()
            try:
                if not self.hyp:
                    self.hyp = ['']
                self.hyp[-1] = hyp

                if self.hyp_callback:
                    self.hyp_callback(self.hyp)

                # Overwrite the last line with the final text and open a new
                # placeholder line below it.
                self.insert = self.buffer.get_iter_at_line(self.buffer.get_line_count() - 1)
                self.buffer.delete(self.insert, self.buffer.get_end_iter())
                self.buffer.insert(self.insert, hyp + '\n...')

                self.hyp.append('')
            finally:
                # Always release the GDK lock, even if the callback or a
                # buffer operation raised; otherwise it would stay held.
                Gdk.threads_leave()
    
        def autoscroll(self, *args):
            """Scroll the transcript to the bottom; connected to the text view's size-allocate signal."""
            vadjustment = self.scrolled.get_vadjustment()
            bottom = vadjustment.get_upper() - vadjustment.get_page_size()
            vadjustment.set_value(bottom)
    
        def button_clicked(self, button):
            """Toggle recognition on or off, using the button label as the state.

            "Speak" -> un-silence the decoder, clear the hypotheses and reset
            the transcript to '...'; any other label -> silence the decoder.
            """
            if button.get_label() != "Speak":
                # Currently listening: stop.
                button.set_label("Speak")
                self.asr.set_property("silent", True)
                return

            # Currently idle: start listening with a fresh transcript.
            button.set_label("Stop")
            self.asr.set_property("silent", False)
            self.hyp = []
            self.buffer.set_text('...')