From 5f5ef0552e94101781d0faaae50a9761d850331c Mon Sep 17 00:00:00 2001 From: tanel <alumae@gmail.com> Date: Tue, 11 Nov 2014 17:54:15 +0200 Subject: [PATCH] Added adaptation-state property, needs fairly recent Kaldi to compile --- src/gstkaldinnet2onlinedecoder.cc | 57 ++++++++++++++++++++++++++++--- src/gstkaldinnet2onlinedecoder.h | 1 + 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc index 1783797..b06054f 100644 --- a/src/gstkaldinnet2onlinedecoder.cc +++ b/src/gstkaldinnet2onlinedecoder.cc @@ -69,6 +69,7 @@ enum { PROP_FST, PROP_WORD_SYMS, PROP_DO_ENDPOINTING, + PROP_ADAPTATION_STATE, PROP_LAST }; @@ -176,6 +177,14 @@ static void gst_kaldinnet2onlinedecoder_class_init( FALSE, (GParamFlags) G_PARAM_READWRITE)); + g_object_class_install_property( + gobject_class, + PROP_ADAPTATION_STATE, + g_param_spec_string("adaptation-state", "Adaptation state", + "Current adaptation state, in stringified form, set to empty string to reset", + "", + (GParamFlags) G_PARAM_READWRITE)); + gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new( "partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST, G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result), @@ -334,6 +343,8 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, guint prop_id, const GValue * value, GParamSpec * pspec) { + + Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object); switch (prop_id) { @@ -355,6 +366,35 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, case PROP_DO_ENDPOINTING: filter->do_endpointing = g_value_get_boolean(value); break; + case PROP_ADAPTATION_STATE: + { + if (G_VALUE_HOLDS_STRING(value)) { + gchar * adaptation_state_string = g_value_dup_string(value); + if (strlen(adaptation_state_string) > 0) { + std::istringstream str(adaptation_state_string); + try { + filter->adaptation_state->Read(str, false); + } catch (std::runtime_error& e) { + GST_WARNING_OBJECT(filter, "Failed to read adaptation state from given string, resetting instead"); + delete filter->adaptation_state; + filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( + filter->feature_info->ivector_extractor_info); + } + } else { + GST_DEBUG_OBJECT(filter, "Resetting adaptation state"); + delete filter->adaptation_state; + filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( + filter->feature_info->ivector_extractor_info); + } + delete adaptation_state_string; + } else { + GST_DEBUG_OBJECT(filter, "Resetting adaptation state"); + delete filter->adaptation_state; + filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( + filter->feature_info->ivector_extractor_info); + } + } + break; default: if (prop_id >= PROP_LAST) { const gchar* name = g_param_spec_get_name(pspec); @@ -402,6 +442,8 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object, float tmp_float; double tmp_double; std::string tmp_string; + std::ostringstream string_stream; + Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object); @@ -421,6 +463,11 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object, case PROP_DO_ENDPOINTING: g_value_set_boolean(value, filter->do_endpointing); break; + case PROP_ADAPTATION_STATE: + string_stream.clear(); + filter->adaptation_state->Write(string_stream, false); + g_value_set_string(value, string_stream.str().c_str()); + break; default: if (prop_id >= PROP_LAST) { const gchar* name = g_param_spec_get_name(pspec); @@ -547,14 +594,11 @@ static void gst_kaldinnet2onlinedecoder_loop( int32 chunk_length = int32(filter->sample_rate * chunk_length_secs); - OnlineIvectorExtractorAdaptationState adaptation_state( - filter->feature_info->ivector_extractor_info); - bool more_data = true; while (more_data) { OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info)); - feature_pipeline.SetAdaptationState(adaptation_state); + feature_pipeline.SetAdaptationState(*(filter->adaptation_state)); SingleUtteranceNnet2Decoder decoder(*(filter->nnet2_decoding_config), *(filter->trans_model), *(filter->nnet), @@ -603,7 +647,7 @@ static void gst_kaldinnet2onlinedecoder_loop( &tot_like, &num_words); if (num_words > 0) { // Only update adaptation state if the utterance was not empty - feature_pipeline.GetAdaptationState(&adaptation_state); + feature_pipeline.GetAdaptationState(filter->adaptation_state); } } else { GST_DEBUG_OBJECT(filter, "Less than 0.1 seconds decoded, discarding"); @@ -770,6 +814,9 @@ gst_kaldinnet2onlinedecoder_allocate( return false; } + filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( + filter->feature_info->ivector_extractor_info); + } return true; } diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h index b24ce10..ef02fae 100644 --- a/src/gstkaldinnet2onlinedecoder.h +++ b/src/gstkaldinnet2onlinedecoder.h @@ -78,6 +78,7 @@ struct _Gstkaldinnet2onlinedecoder { fst::SymbolTable *word_syms; int sample_rate; gboolean decoding; + OnlineIvectorExtractorAdaptationState *adaptation_state; }; struct _Gstkaldinnet2onlinedecoderClass { -- GitLab