Skip to content
Snippets Groups Projects
Commit 1f5faf4a authored by tanel
Browse files

saving adaptation state between utterances of a single session

parent 08ae041c
No related branches found
No related tags found
No related merge requests found
all: all:
KALDI_ROOT?=/home/tanel/tools/kaldi-online KALDI_ROOT?=/home/tanel/tools/kaldi-trunk
ifeq ("$(wildcard $(KALDI_ROOT)/src/kaldi.mk)","") ifeq ("$(wildcard $(KALDI_ROOT)/src/kaldi.mk)","")
$(error Cannot find Kaldi's makefile $(KALDI_ROOT)/src/kaldi.mk. \ $(error Cannot find Kaldi's makefile $(KALDI_ROOT)/src/kaldi.mk. \
......
...@@ -462,7 +462,7 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object, ...@@ -462,7 +462,7 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
static void gst_kaldinnet2onlinedecoder_final_result( static void gst_kaldinnet2onlinedecoder_final_result(
Gstkaldinnet2onlinedecoder * filter, const CompactLattice &clat, Gstkaldinnet2onlinedecoder * filter, const CompactLattice &clat,
int64 *tot_num_frames, double *tot_like) { int64 *tot_num_frames, double *tot_like, guint *num_words) {
if (clat.NumStates() == 0) { if (clat.NumStates() == 0) {
KALDI_WARN<< "Empty lattice."; KALDI_WARN<< "Empty lattice.";
return; return;
...@@ -499,6 +499,7 @@ static void gst_kaldinnet2onlinedecoder_final_result( ...@@ -499,6 +499,7 @@ static void gst_kaldinnet2onlinedecoder_final_result(
GST_DEBUG_OBJECT(filter, "Final: %s", sentence.str().c_str()); GST_DEBUG_OBJECT(filter, "Final: %s", sentence.str().c_str());
guint hyp_length = sentence.str().length(); guint hyp_length = sentence.str().length();
*num_words = hyp_length;
if (hyp_length > 0) { if (hyp_length > 0) {
GstBuffer *buffer = gst_buffer_new_and_alloc(hyp_length + 1); GstBuffer *buffer = gst_buffer_new_and_alloc(hyp_length + 1);
gst_buffer_fill(buffer, 0, sentence.str().c_str(), hyp_length); gst_buffer_fill(buffer, 0, sentence.str().c_str(), hyp_length);
...@@ -545,10 +546,11 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -545,10 +546,11 @@ static void gst_kaldinnet2onlinedecoder_loop(
int32 chunk_length = int32(filter->sample_rate * chunk_length_secs); int32 chunk_length = int32(filter->sample_rate * chunk_length_secs);
OnlineIvectorExtractorAdaptationState adaptation_state(
filter->feature_info->ivector_extractor_info);
bool more_data = true; bool more_data = true;
while (more_data) { while (more_data) {
OnlineIvectorExtractorAdaptationState adaptation_state(
filter->feature_info->ivector_extractor_info);
OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info)); OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info));
feature_pipeline.SetAdaptationState(adaptation_state); feature_pipeline.SetAdaptationState(adaptation_state);
...@@ -593,9 +595,25 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -593,9 +595,25 @@ static void gst_kaldinnet2onlinedecoder_loop(
GST_DEBUG_OBJECT(filter, "Lattice done"); GST_DEBUG_OBJECT(filter, "Lattice done");
double tot_like = 0.0; double tot_like = 0.0;
int64 num_frames = 0; int64 num_frames = 0;
guint num_words = 0;
gst_kaldinnet2onlinedecoder_final_result(filter, clat, &num_frames, gst_kaldinnet2onlinedecoder_final_result(filter, clat, &num_frames,
&tot_like); &tot_like, &num_words);
if (num_words > 0) {
// Only update adaptation state if the utterance was not empty
// We might avoid updating the adaptation state if
// we felt the utterance had low confidence.
feature_pipeline.GetAdaptationState(&adaptation_state);
}
} }
// We should also push out adaptation state now, something like this
// bool binary = false;
// Output ko("adaptation_state.txt", binary);
// adaptation_state.Write(ko.Stream(), binary);
// ko.Close();
GST_DEBUG_OBJECT(filter, "Finished decoding loop"); GST_DEBUG_OBJECT(filter, "Finished decoding loop");
GST_DEBUG_OBJECT(filter, "Pushing EOS event"); GST_DEBUG_OBJECT(filter, "Pushing EOS event");
gst_pad_push_event(filter->srcpad, gst_event_new_eos()); gst_pad_push_event(filter->srcpad, gst_event_new_eos());
...@@ -699,6 +717,7 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, ...@@ -699,6 +717,7 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad,
if (G_UNLIKELY(!filter->audio_source)) if (G_UNLIKELY(!filter->audio_source))
goto not_negotiated; goto not_negotiated;
if (!filter->silent) { if (!filter->silent) {
GST_DEBUG_OBJECT(filter, "Pushing buffer of length %zu", gst_buffer_get_size(buf));
filter->audio_source->PushBuffer(buf); filter->audio_source->PushBuffer(buf);
} }
gst_buffer_unref(buf); gst_buffer_unref(buf);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment