From 5f5ef0552e94101781d0faaae50a9761d850331c Mon Sep 17 00:00:00 2001
From: tanel <alumae@gmail.com>
Date: Tue, 11 Nov 2014 17:54:15 +0200
Subject: [PATCH] Added adaptation-state property, needs fairly recent Kaldi to
 compile

---
 src/gstkaldinnet2onlinedecoder.cc | 57 ++++++++++++++++++++++++++++---
 src/gstkaldinnet2onlinedecoder.h  |  1 +
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc
index 1783797..b06054f 100644
--- a/src/gstkaldinnet2onlinedecoder.cc
+++ b/src/gstkaldinnet2onlinedecoder.cc
@@ -69,6 +69,7 @@ enum {
   PROP_FST,
   PROP_WORD_SYMS,
   PROP_DO_ENDPOINTING,
+  PROP_ADAPTATION_STATE,
   PROP_LAST
 };
 
@@ -176,6 +177,14 @@ static void gst_kaldinnet2onlinedecoder_class_init(
           FALSE,
           (GParamFlags) G_PARAM_READWRITE));
 
+  g_object_class_install_property(
+      gobject_class,
+      PROP_ADAPTATION_STATE,
+      g_param_spec_string("adaptation-state", "Adaptation state",
+                          "Current adaptation state, in stringified form, set to empty string to reset",
+                          "",
+                          (GParamFlags) G_PARAM_READWRITE));
+
   gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new(
       "partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST,
       G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result),
@@ -334,6 +343,8 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
                                                      guint prop_id,
                                                      const GValue * value,
                                                      GParamSpec * pspec) {
+
+
   Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object);
 
   switch (prop_id) {
@@ -355,6 +366,35 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
     case PROP_DO_ENDPOINTING:
       filter->do_endpointing = g_value_get_boolean(value);
       break;
+    case PROP_ADAPTATION_STATE:
+      {
+        if (G_VALUE_HOLDS_STRING(value)) {
+          gchar * adaptation_state_string = g_value_dup_string(value);
+          if (strlen(adaptation_state_string) > 0) {
+            std::istringstream str(adaptation_state_string);
+            try {
+              filter->adaptation_state->Read(str, false);
+            } catch (std::runtime_error& e) {
+              GST_WARNING_OBJECT(filter, "Failed to read adaptation state from given string, resetting instead");
+              delete filter->adaptation_state;
+              filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
+                  filter->feature_info->ivector_extractor_info);
+            }
+          } else {
+            GST_DEBUG_OBJECT(filter, "Resetting adaptation state");
+            delete filter->adaptation_state;
+            filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
+                filter->feature_info->ivector_extractor_info);
+          }
+          delete adaptation_state_string;
+        } else {
+          GST_DEBUG_OBJECT(filter, "Resetting adaptation state");
+          delete filter->adaptation_state;
+          filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
+              filter->feature_info->ivector_extractor_info);
+        }
+      }
+      break;
     default:
       if (prop_id >= PROP_LAST) {
         const gchar* name = g_param_spec_get_name(pspec);
@@ -402,6 +442,8 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
   float tmp_float;
   double tmp_double;
   std::string tmp_string;
+  std::ostringstream string_stream;
+
 
   Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object);
 
@@ -421,6 +463,11 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
     case PROP_DO_ENDPOINTING:
       g_value_set_boolean(value, filter->do_endpointing);
       break;
+    case PROP_ADAPTATION_STATE:
+      string_stream.clear();
+      filter->adaptation_state->Write(string_stream, false);
+      g_value_set_string(value, string_stream.str().c_str());
+      break;
     default:
       if (prop_id >= PROP_LAST) {
         const gchar* name = g_param_spec_get_name(pspec);
@@ -547,14 +594,11 @@ static void gst_kaldinnet2onlinedecoder_loop(
 
   int32 chunk_length = int32(filter->sample_rate * chunk_length_secs);
 
-  OnlineIvectorExtractorAdaptationState adaptation_state(
-      filter->feature_info->ivector_extractor_info);
-
   bool more_data = true;
   while (more_data) {
 
     OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info));
-    feature_pipeline.SetAdaptationState(adaptation_state);
+    feature_pipeline.SetAdaptationState(*(filter->adaptation_state));
 
     SingleUtteranceNnet2Decoder decoder(*(filter->nnet2_decoding_config),
                                         *(filter->trans_model), *(filter->nnet),
@@ -603,7 +647,7 @@ static void gst_kaldinnet2onlinedecoder_loop(
                                                &tot_like, &num_words);
       if (num_words > 0) {
         // Only update adaptation state if the utterance was not empty
-        feature_pipeline.GetAdaptationState(&adaptation_state);
+        feature_pipeline.GetAdaptationState(filter->adaptation_state);
       }
     } else {
       GST_DEBUG_OBJECT(filter, "Less than 0.1 seconds decoded, discarding");
@@ -770,6 +814,9 @@ gst_kaldinnet2onlinedecoder_allocate(
       return false;
     }
 
+    filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
+        filter->feature_info->ivector_extractor_info);
+
   }
   return true;
 }
diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h
index b24ce10..ef02fae 100644
--- a/src/gstkaldinnet2onlinedecoder.h
+++ b/src/gstkaldinnet2onlinedecoder.h
@@ -78,6 +78,7 @@ struct _Gstkaldinnet2onlinedecoder {
   fst::SymbolTable *word_syms;
   int sample_rate;
   gboolean decoding;
+  OnlineIvectorExtractorAdaptationState *adaptation_state;
 };
 
 struct _Gstkaldinnet2onlinedecoderClass {
-- 
GitLab