diff --git a/src/Makefile b/src/Makefile index 9605899a2e97b4101155698f5343219538def407..fa772430cae342aaaa45f659e9f972299e775082 100644 --- a/src/Makefile +++ b/src/Makefile @@ -1,10 +1,10 @@ all: -KALDI_ROOT?=~/tools/kaldi-online +KALDI_ROOT?=/home/tanel/tools/kaldi-online ifeq ("$(wildcard $(KALDI_ROOT)/src/kaldi.mk)","") $(error Cannot find Kaldi's makefile $(KALDI_ROOT)/src/kaldi.mk. \ -Specify Kaldi's root directory using KALDI_ROOT when issuing make, e.g.: `KALDI_ROOT=~/tools/kaldi-trunk make` ) +Specify Kaldi's root directory using KALDI_ROOT when issuing make, e.g.: `KALDI_ROOT=/home/tanel/tools/kaldi-trunk make` ) endif include $(KALDI_ROOT)/src/kaldi.mk diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc index e79f05506841a36b2c493032620890910d2b5249..39601f7ae0c77cf4e1930e2321f0da0e9c01734a 100644 --- a/src/gstkaldinnet2onlinedecoder.cc +++ b/src/gstkaldinnet2onlinedecoder.cc @@ -78,7 +78,7 @@ enum { /* the capabilities of the inputs and outputs. * */ -static GstStaticPadTemplate sink_factory = +static GstStaticPadTemplate sink_template = GST_STATIC_PAD_TEMPLATE("sink", GST_PAD_SINK, GST_PAD_ALWAYS, @@ -86,9 +86,9 @@ GST_STATIC_PAD_TEMPLATE("sink", "audio/x-raw, " "format = (string) S16LE, " "channels = (int) 1, " - "rate = (int) 16000 ")); + "rate = (int) [ 1, MAX ]")); -static GstStaticPadTemplate src_factory = +static GstStaticPadTemplate src_template = GST_STATIC_PAD_TEMPLATE("src", GST_PAD_SRC, GST_PAD_ALWAYS, @@ -121,6 +121,8 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state( GstElement *element, GstStateChange transition); +static gboolean gst_kaldinnet2onlinedecoder_query (GstPad *pad, GstObject * parent, GstQuery * query); + static void gst_kaldinnet2onlinedecoder_finalize(GObject * object); /* GObject vmethod implementations */ @@ -191,10 +193,11 @@ static void gst_kaldinnet2onlinedecoder_class_init( gstelement_class, "KaldiNNet2OnlineDecoder", "Speech/Audio", "Convert speech to text", "Tanel Alumae <tanel.alumae@phon.ioc.ee>"); + gst_element_class_add_pad_template(gstelement_class, - gst_static_pad_template_get(&src_factory)); + gst_static_pad_template_get(&src_template)); gst_element_class_add_pad_template( - gstelement_class, gst_static_pad_template_get(&sink_factory)); + gstelement_class, gst_static_pad_template_get(&sink_template)); } /* initialize the new element @@ -211,17 +214,21 @@ static void gst_kaldinnet2onlinedecoder_init( double tmp_double; std::string tmp_string; - filter->sinkpad = gst_pad_new_from_static_template(&sink_factory, "sink"); + filter->sinkpad = NULL; + + filter->sinkpad = gst_pad_new_from_static_template(&sink_template, "sink"); gst_pad_set_event_function( filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_sink_event)); gst_pad_set_chain_function( filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_chain)); + gst_pad_set_query_function( + filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_query)); - gst_pad_use_fixed_caps(filter->sinkpad); gst_element_add_pad(GST_ELEMENT(filter), filter->sinkpad); - filter->srcpad = gst_pad_new_from_static_template(&src_factory, "src"); + + filter->srcpad = gst_pad_new_from_static_template(&src_template, "src"); gst_pad_use_fixed_caps(filter->srcpad); gst_element_add_pad(GST_ELEMENT(filter), filter->srcpad); @@ -240,6 +247,10 @@ static void gst_kaldinnet2onlinedecoder_init( filter->feature_config->Register(filter->simple_options); filter->nnet2_decoding_config->Register(filter->simple_options); + // will be set later + filter->feature_info = NULL; + filter->sample_rate = 0; + // init properties from various Kaldi Opts GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter); @@ -316,7 +327,6 @@ static void gst_kaldinnet2onlinedecoder_init( } i += 1; } - } static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, @@ -533,7 +543,7 @@ static void gst_kaldinnet2onlinedecoder_loop( BaseFloat chunk_length_secs = 0.05; BaseFloat traceback_period_secs = 1.0; - int32 chunk_length = int32(16000 * chunk_length_secs); + int32 chunk_length = int32(filter->sample_rate * chunk_length_secs); bool more_data = true; while (more_data) { @@ -556,7 +566,7 @@ static void gst_kaldinnet2onlinedecoder_loop( while (true) { more_data = filter->audio_source->Read(&wave_part); //GST_DEBUG_OBJECT(filter, "Read %d samples", wave_part.Dim()); - feature_pipeline.AcceptWaveform(16000, wave_part); + feature_pipeline.AcceptWaveform(filter->sample_rate, wave_part); //GST_DEBUG_OBJECT(filter, "Advancing decoding.."); if (!more_data) { feature_pipeline.InputFinished(); @@ -600,6 +610,73 @@ static void gst_kaldinnet2onlinedecoder_loop( /* GstElement vmethod implementations */ +static gboolean +gst_kaldinnet2onlinedecoder_query (GstPad *pad, GstObject * parent, GstQuery * query) { + gboolean ret; + Gstkaldinnet2onlinedecoder *filter; + + filter = GST_KALDINNET2ONLINEDECODER(parent); + + switch (GST_QUERY_TYPE (query)) { + case GST_QUERY_CAPS: { + if (filter->feature_info == NULL) { + filter->feature_info = new OnlineNnet2FeaturePipelineInfo(*(filter->feature_config)); + filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq; + } + GstCaps *new_caps = gst_caps_new_simple ("audio/x-raw", + "format", G_TYPE_STRING, "S16LE", + "rate", G_TYPE_INT, filter->sample_rate, + "channels", G_TYPE_INT, 1, NULL); + + +// GstPad *otherpad; +// GstCaps *temp, *caps, *filt, *tcaps; +// gint i; +// +// otherpad = (pad == filter->srcpad) ? filter->sinkpad : +// filter->srcpad; +// caps = gst_pad_get_allowed_caps (otherpad); +// +// gst_query_parse_caps (query, &filt); +// +// /* We support *any* samplerate, indifferent from the samplerate +// * supported by the linked elements on both sides. */ +// for (i = 0; i < gst_caps_get_size (caps); i++) { +// GstStructure *structure = gst_caps_get_structure (caps, i); +// +// gst_structure_remove_field (structure, "rate"); +// } +// +// /* make sure we only return results that intersect our +// * padtemplate */ +// tcaps = gst_pad_get_pad_template_caps (pad); +// if (tcaps) { +// temp = gst_caps_intersect (caps, tcaps); +// gst_caps_unref (caps); +// gst_caps_unref (tcaps); +// caps = temp; +// } +// /* filter against the query filter when needed */ +// if (filt) { +// temp = gst_caps_intersect (caps, filt); +// gst_caps_unref (caps); +// caps = temp; +// } + GST_DEBUG_OBJECT (filter, "Setting caps query result: %" GST_PTR_FORMAT, new_caps); + gst_query_set_caps_result (query, new_caps); + gst_caps_unref (new_caps); + ret = TRUE; + break; + } + default: + ret = gst_pad_query_default (pad, parent, query); + break; + } + return ret; +} + + + /* this function handles sink events */ static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad, GstObject * parent, @@ -647,9 +724,8 @@ static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad, static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, GstObject * parent, GstBuffer * buf) { - Gstkaldinnet2onlinedecoder *filter; + Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(parent); - filter = GST_KALDINNET2ONLINEDECODER(parent); if (G_UNLIKELY(!filter->audio_source)) goto not_negotiated; @@ -669,15 +745,21 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, } } -static bool gst_kaldinnet2onlinedecoder_allocate( + +static bool +gst_kaldinnet2onlinedecoder_allocate( Gstkaldinnet2onlinedecoder * filter) { - if (!filter->feature_info) { + if (!filter->decode_fst) { GST_INFO_OBJECT(filter, "Loading Kaldi models and feature extractor"); filter->audio_source = new GstBufferSource(); - filter->feature_info = new OnlineNnet2FeaturePipelineInfo( - *(filter->feature_config)); + if (filter->feature_info == NULL) { + filter->feature_info = new OnlineNnet2FeaturePipelineInfo(*(filter->feature_config)); + filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq; + } + + filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq; filter->trans_model = new TransitionModel(); filter->nnet = new nnet2::AmNnet(); @@ -710,6 +792,7 @@ static bool gst_kaldinnet2onlinedecoder_deallocate( static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state( GstElement *element, GstStateChange transition) { + GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS; Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(element); diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h index f431e8f78f1d5f294b2100ec173762c017c3f6cd..01fceb8d142b3838daf26632271dd5dd84cf9880 100644 --- a/src/gstkaldinnet2onlinedecoder.h +++ b/src/gstkaldinnet2onlinedecoder.h @@ -84,6 +84,8 @@ struct _Gstkaldinnet2onlinedecoder GstPad *sinkpad, *srcpad; + GstCaps *sink_caps; + gboolean silent; gboolean do_endpointing; GstBufferSource *audio_source; @@ -102,6 +104,7 @@ struct _Gstkaldinnet2onlinedecoder nnet2::AmNnet *nnet; fst::Fst<fst::StdArc> *decode_fst; fst::SymbolTable *word_syms; + int sample_rate; }; struct _Gstkaldinnet2onlinedecoderClass