From 28dd859da73aec5ed235c9175a410ac8f0ef0dec Mon Sep 17 00:00:00 2001
From: tanel <alumae@gmail.com>
Date: Wed, 7 Jan 2015 17:22:00 +0200
Subject: [PATCH] Made chunk length configurable, can be increased to e.g. 0.2
 to get some speedup when multithreaded BLAS is used

---
 demo/transcribe-audio.sh          |  1 +
 src/gstkaldinnet2onlinedecoder.cc | 36 ++++++++++++++++++++++---------
 src/gstkaldinnet2onlinedecoder.h  |  1 +
 3 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/demo/transcribe-audio.sh b/demo/transcribe-audio.sh
index dc771b5..2b4d0e9 100755
--- a/demo/transcribe-audio.sh
+++ b/demo/transcribe-audio.sh
@@ -28,4 +28,5 @@ kaldinnet2onlinedecoder \
   lattice-beam=5.0 \
   do-endpointing=true \
   endpoint-silence-phones="1:2:3:4:5:6:7:8:9:10" \
+  chunk-length-in-secs=0.2 \
 ! filesink location=/dev/stdout buffer-mode=2
diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc
index 0fd8e81..4759258 100644
--- a/src/gstkaldinnet2onlinedecoder.cc
+++ b/src/gstkaldinnet2onlinedecoder.cc
@@ -72,6 +72,7 @@ enum {
   PROP_ADAPTATION_STATE,
   PROP_INVERSE_SCALE,
   PROP_LMWT_SCALE,
+  PROP_CHUNK_LENGTH_IN_SECS,
   PROP_LAST
 };
 
@@ -79,6 +80,7 @@ enum {
 #define DEFAULT_FST             "HCLG.fst"
 #define DEFAULT_WORD_SYMS       "words.txt"
 #define DEFAULT_LMWT_SCALE	1.0
+#define DEFAULT_CHUNK_LENGTH_IN_SECS  0.05
 
 /* the capabilities of the inputs and outputs.
  *
@@ -208,6 +210,17 @@ static void gst_kaldinnet2onlinedecoder_class_init(
           DEFAULT_LMWT_SCALE,
           (GParamFlags) G_PARAM_READWRITE));
 
+  g_object_class_install_property(
+      gobject_class,
+      PROP_CHUNK_LENGTH_IN_SECS,
+      g_param_spec_float(
+          "chunk-length-in-secs", "Length of a audio chunk that is processed at a time",
+          "Smaller values decrease latency, bigger values (e.g. 0.2) improve speed if multithreaded BLAS/MKL is used",
+          0.05,
+          G_MAXFLOAT,
+          DEFAULT_CHUNK_LENGTH_IN_SECS,
+          (GParamFlags) G_PARAM_READWRITE));
+
   gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new(
       "partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST,
       G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result),
@@ -285,6 +298,7 @@ static void gst_kaldinnet2onlinedecoder_init(
   filter->decoding = false;
   filter->lmwt_scale = DEFAULT_LMWT_SCALE;
   filter->inverse_scale = FALSE;
+  filter->chunk_length_in_secs = DEFAULT_CHUNK_LENGTH_IN_SECS;
 
   // init properties from various Kaldi Opts
   GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter);
@@ -397,6 +411,9 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
     case PROP_LMWT_SCALE:
       filter->lmwt_scale = g_value_get_float(value);
       break;
+    case PROP_CHUNK_LENGTH_IN_SECS:
+      filter->chunk_length_in_secs = g_value_get_float(value);
+      break;
     case PROP_ADAPTATION_STATE:
       {
         if (G_VALUE_HOLDS_STRING(value)) {
@@ -500,6 +517,9 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
     case PROP_LMWT_SCALE:
       g_value_set_float(value, filter->lmwt_scale);
       break;
+    case PROP_CHUNK_LENGTH_IN_SECS:
+      g_value_set_float(value, filter->chunk_length_in_secs);
+      break;
     case PROP_ADAPTATION_STATE:
       string_stream.clear();
       if (filter->adaptation_state) {
@@ -639,10 +659,9 @@ static void gst_kaldinnet2onlinedecoder_loop(
     Gstkaldinnet2onlinedecoder * filter) {
 
   GST_DEBUG_OBJECT(filter, "Starting decoding loop..");
-  BaseFloat chunk_length_secs = 0.05;
   BaseFloat traceback_period_secs = 1.0;
 
-  int32 chunk_length = int32(filter->sample_rate * chunk_length_secs);
+  int32 chunk_length = int32(filter->sample_rate * filter->chunk_length_in_secs);
 
   bool more_data = true;
   while (more_data) {
@@ -675,7 +694,7 @@ static void gst_kaldinnet2onlinedecoder_loop(
         GST_DEBUG_OBJECT(filter, "Endpoint detected!");
         break;
       }
-      num_seconds_decoded += chunk_length_secs;
+      num_seconds_decoded += filter->chunk_length_in_secs;
       if (num_seconds_decoded - last_traceback > traceback_period_secs) {
         Lattice lat;
         decoder.GetBestPath(false, &lat);
@@ -704,13 +723,6 @@ static void gst_kaldinnet2onlinedecoder_loop(
     }
   }
 
-  // TODO: write adaptation state to string and push it out
-  // bool binary = false;
-  // Output ko("adaptation_state.txt", binary);
-  // adaptation_state.Write(ko.Stream(), binary);
-  // ko.Close();
-
-
   GST_DEBUG_OBJECT(filter, "Finished decoding loop");
   GST_DEBUG_OBJECT(filter, "Pushing EOS event");
   gst_pad_push_event(filter->srcpad, gst_event_new_eos());
@@ -933,6 +945,10 @@ static void gst_kaldinnet2onlinedecoder_finalize(GObject * object) {
   if (filter->word_syms) {
     delete filter->word_syms;
   }
+  if (filter->adaptation_state) {
+    delete filter->adaptation_state;
+  }
+
   G_OBJECT_CLASS(parent_class)->finalize(object);
 }
 
diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h
index 3adf837..e1fc7e6 100644
--- a/src/gstkaldinnet2onlinedecoder.h
+++ b/src/gstkaldinnet2onlinedecoder.h
@@ -80,6 +80,7 @@ struct _Gstkaldinnet2onlinedecoder {
   fst::SymbolTable *word_syms;
   int sample_rate;
   gboolean decoding;
+  float chunk_length_in_secs;
   OnlineIvectorExtractorAdaptationState *adaptation_state;
 };
 
-- 
GitLab