Skip to content
Snippets Groups Projects
Commit 5f5ef055 authored by tanel's avatar tanel
Browse files

Added adaptation-state property, needs fairly recent Kaldi to compile

parent 09c35435
Branches
No related tags found
No related merge requests found
......@@ -69,6 +69,7 @@ enum {
PROP_FST,
PROP_WORD_SYMS,
PROP_DO_ENDPOINTING,
PROP_ADAPTATION_STATE,
PROP_LAST
};
......@@ -176,6 +177,14 @@ static void gst_kaldinnet2onlinedecoder_class_init(
FALSE,
(GParamFlags) G_PARAM_READWRITE));
g_object_class_install_property(
gobject_class,
PROP_ADAPTATION_STATE,
g_param_spec_string("adaptation-state", "Adaptation state",
"Current adaptation state, in stringified form, set to empty string to reset",
"",
(GParamFlags) G_PARAM_READWRITE));
gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new(
"partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST,
G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result),
......@@ -334,6 +343,8 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
guint prop_id,
const GValue * value,
GParamSpec * pspec) {
Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object);
switch (prop_id) {
......@@ -355,6 +366,35 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
case PROP_DO_ENDPOINTING:
filter->do_endpointing = g_value_get_boolean(value);
break;
case PROP_ADAPTATION_STATE:
{
if (G_VALUE_HOLDS_STRING(value)) {
gchar * adaptation_state_string = g_value_dup_string(value);
if (strlen(adaptation_state_string) > 0) {
std::istringstream str(adaptation_state_string);
try {
filter->adaptation_state->Read(str, false);
} catch (std::runtime_error& e) {
GST_WARNING_OBJECT(filter, "Failed to read adaptation state from given string, resetting instead");
delete filter->adaptation_state;
filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
filter->feature_info->ivector_extractor_info);
}
} else {
GST_DEBUG_OBJECT(filter, "Resetting adaptation state");
delete filter->adaptation_state;
filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
filter->feature_info->ivector_extractor_info);
}
delete adaptation_state_string;
} else {
GST_DEBUG_OBJECT(filter, "Resetting adaptation state");
delete filter->adaptation_state;
filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
filter->feature_info->ivector_extractor_info);
}
}
break;
default:
if (prop_id >= PROP_LAST) {
const gchar* name = g_param_spec_get_name(pspec);
......@@ -402,6 +442,8 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
float tmp_float;
double tmp_double;
std::string tmp_string;
std::ostringstream string_stream;
Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(object);
......@@ -421,6 +463,11 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
case PROP_DO_ENDPOINTING:
g_value_set_boolean(value, filter->do_endpointing);
break;
case PROP_ADAPTATION_STATE:
string_stream.clear();
filter->adaptation_state->Write(string_stream, false);
g_value_set_string(value, string_stream.str().c_str());
break;
default:
if (prop_id >= PROP_LAST) {
const gchar* name = g_param_spec_get_name(pspec);
......@@ -547,14 +594,11 @@ static void gst_kaldinnet2onlinedecoder_loop(
int32 chunk_length = int32(filter->sample_rate * chunk_length_secs);
OnlineIvectorExtractorAdaptationState adaptation_state(
filter->feature_info->ivector_extractor_info);
bool more_data = true;
while (more_data) {
OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info));
feature_pipeline.SetAdaptationState(adaptation_state);
feature_pipeline.SetAdaptationState(*(filter->adaptation_state));
SingleUtteranceNnet2Decoder decoder(*(filter->nnet2_decoding_config),
*(filter->trans_model), *(filter->nnet),
......@@ -603,7 +647,7 @@ static void gst_kaldinnet2onlinedecoder_loop(
&tot_like, &num_words);
if (num_words > 0) {
// Only update adaptation state if the utterance was not empty
feature_pipeline.GetAdaptationState(&adaptation_state);
feature_pipeline.GetAdaptationState(filter->adaptation_state);
}
} else {
GST_DEBUG_OBJECT(filter, "Less than 0.1 seconds decoded, discarding");
......@@ -770,6 +814,9 @@ gst_kaldinnet2onlinedecoder_allocate(
return false;
}
filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
filter->feature_info->ivector_extractor_info);
}
return true;
}
......
......@@ -78,6 +78,7 @@ struct _Gstkaldinnet2onlinedecoder {
fst::SymbolTable *word_syms;
int sample_rate;
gboolean decoding;
OnlineIvectorExtractorAdaptationState *adaptation_state;
};
struct _Gstkaldinnet2onlinedecoderClass {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment