Skip to content
Snippets Groups Projects
Commit 3a764299 authored by tanel's avatar tanel
Browse files

now supports different sample rates

parent cde5f180
No related branches found
No related tags found
No related merge requests found
all: all:
KALDI_ROOT?=~/tools/kaldi-online KALDI_ROOT?=/home/tanel/tools/kaldi-online
ifeq ("$(wildcard $(KALDI_ROOT)/src/kaldi.mk)","") ifeq ("$(wildcard $(KALDI_ROOT)/src/kaldi.mk)","")
$(error Cannot find Kaldi's makefile $(KALDI_ROOT)/src/kaldi.mk. \ $(error Cannot find Kaldi's makefile $(KALDI_ROOT)/src/kaldi.mk. \
Specify Kaldi's root directory using KALDI_ROOT when issuing make, e.g.: `KALDI_ROOT=~/tools/kaldi-trunk make` ) Specify Kaldi's root directory using KALDI_ROOT when issuing make, e.g.: `KALDI_ROOT=/home/tanel/tools/kaldi-trunk make` )
endif endif
include $(KALDI_ROOT)/src/kaldi.mk include $(KALDI_ROOT)/src/kaldi.mk
......
...@@ -78,7 +78,7 @@ enum { ...@@ -78,7 +78,7 @@ enum {
/* the capabilities of the inputs and outputs. /* the capabilities of the inputs and outputs.
* *
*/ */
static GstStaticPadTemplate sink_factory = static GstStaticPadTemplate sink_template =
GST_STATIC_PAD_TEMPLATE("sink", GST_STATIC_PAD_TEMPLATE("sink",
GST_PAD_SINK, GST_PAD_SINK,
GST_PAD_ALWAYS, GST_PAD_ALWAYS,
...@@ -86,9 +86,9 @@ GST_STATIC_PAD_TEMPLATE("sink", ...@@ -86,9 +86,9 @@ GST_STATIC_PAD_TEMPLATE("sink",
"audio/x-raw, " "audio/x-raw, "
"format = (string) S16LE, " "format = (string) S16LE, "
"channels = (int) 1, " "channels = (int) 1, "
"rate = (int) 16000 ")); "rate = (int) [ 1, MAX ]"));
static GstStaticPadTemplate src_factory = static GstStaticPadTemplate src_template =
GST_STATIC_PAD_TEMPLATE("src", GST_STATIC_PAD_TEMPLATE("src",
GST_PAD_SRC, GST_PAD_SRC,
GST_PAD_ALWAYS, GST_PAD_ALWAYS,
...@@ -121,6 +121,8 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, ...@@ -121,6 +121,8 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad,
static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state( static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state(
GstElement *element, GstStateChange transition); GstElement *element, GstStateChange transition);
static gboolean gst_kaldinnet2onlinedecoder_query (GstPad *pad, GstObject * parent, GstQuery * query);
static void gst_kaldinnet2onlinedecoder_finalize(GObject * object); static void gst_kaldinnet2onlinedecoder_finalize(GObject * object);
/* GObject vmethod implementations */ /* GObject vmethod implementations */
...@@ -191,10 +193,11 @@ static void gst_kaldinnet2onlinedecoder_class_init( ...@@ -191,10 +193,11 @@ static void gst_kaldinnet2onlinedecoder_class_init(
gstelement_class, "KaldiNNet2OnlineDecoder", "Speech/Audio", gstelement_class, "KaldiNNet2OnlineDecoder", "Speech/Audio",
"Convert speech to text", "Tanel Alumae <tanel.alumae@phon.ioc.ee>"); "Convert speech to text", "Tanel Alumae <tanel.alumae@phon.ioc.ee>");
gst_element_class_add_pad_template(gstelement_class, gst_element_class_add_pad_template(gstelement_class,
gst_static_pad_template_get(&src_factory)); gst_static_pad_template_get(&src_template));
gst_element_class_add_pad_template( gst_element_class_add_pad_template(
gstelement_class, gst_static_pad_template_get(&sink_factory)); gstelement_class, gst_static_pad_template_get(&sink_template));
} }
/* initialize the new element /* initialize the new element
...@@ -211,17 +214,21 @@ static void gst_kaldinnet2onlinedecoder_init( ...@@ -211,17 +214,21 @@ static void gst_kaldinnet2onlinedecoder_init(
double tmp_double; double tmp_double;
std::string tmp_string; std::string tmp_string;
filter->sinkpad = gst_pad_new_from_static_template(&sink_factory, "sink"); filter->sinkpad = NULL;
filter->sinkpad = gst_pad_new_from_static_template(&sink_template, "sink");
gst_pad_set_event_function( gst_pad_set_event_function(
filter->sinkpad, filter->sinkpad,
GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_sink_event)); GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_sink_event));
gst_pad_set_chain_function( gst_pad_set_chain_function(
filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_chain)); filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_chain));
gst_pad_set_query_function(
filter->sinkpad, GST_DEBUG_FUNCPTR(gst_kaldinnet2onlinedecoder_query));
gst_pad_use_fixed_caps(filter->sinkpad);
gst_element_add_pad(GST_ELEMENT(filter), filter->sinkpad); gst_element_add_pad(GST_ELEMENT(filter), filter->sinkpad);
filter->srcpad = gst_pad_new_from_static_template(&src_factory, "src");
filter->srcpad = gst_pad_new_from_static_template(&src_template, "src");
gst_pad_use_fixed_caps(filter->srcpad); gst_pad_use_fixed_caps(filter->srcpad);
gst_element_add_pad(GST_ELEMENT(filter), filter->srcpad); gst_element_add_pad(GST_ELEMENT(filter), filter->srcpad);
...@@ -240,6 +247,10 @@ static void gst_kaldinnet2onlinedecoder_init( ...@@ -240,6 +247,10 @@ static void gst_kaldinnet2onlinedecoder_init(
filter->feature_config->Register(filter->simple_options); filter->feature_config->Register(filter->simple_options);
filter->nnet2_decoding_config->Register(filter->simple_options); filter->nnet2_decoding_config->Register(filter->simple_options);
// will be set later
filter->feature_info = NULL;
filter->sample_rate = 0;
// init properties from various Kaldi Opts // init properties from various Kaldi Opts
GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter); GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter);
...@@ -316,7 +327,6 @@ static void gst_kaldinnet2onlinedecoder_init( ...@@ -316,7 +327,6 @@ static void gst_kaldinnet2onlinedecoder_init(
} }
i += 1; i += 1;
} }
} }
static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
...@@ -533,7 +543,7 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -533,7 +543,7 @@ static void gst_kaldinnet2onlinedecoder_loop(
BaseFloat chunk_length_secs = 0.05; BaseFloat chunk_length_secs = 0.05;
BaseFloat traceback_period_secs = 1.0; BaseFloat traceback_period_secs = 1.0;
int32 chunk_length = int32(16000 * chunk_length_secs); int32 chunk_length = int32(filter->sample_rate * chunk_length_secs);
bool more_data = true; bool more_data = true;
while (more_data) { while (more_data) {
...@@ -556,7 +566,7 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -556,7 +566,7 @@ static void gst_kaldinnet2onlinedecoder_loop(
while (true) { while (true) {
more_data = filter->audio_source->Read(&wave_part); more_data = filter->audio_source->Read(&wave_part);
//GST_DEBUG_OBJECT(filter, "Read %d samples", wave_part.Dim()); //GST_DEBUG_OBJECT(filter, "Read %d samples", wave_part.Dim());
feature_pipeline.AcceptWaveform(16000, wave_part); feature_pipeline.AcceptWaveform(filter->sample_rate, wave_part);
//GST_DEBUG_OBJECT(filter, "Advancing decoding.."); //GST_DEBUG_OBJECT(filter, "Advancing decoding..");
if (!more_data) { if (!more_data) {
feature_pipeline.InputFinished(); feature_pipeline.InputFinished();
...@@ -600,6 +610,73 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -600,6 +610,73 @@ static void gst_kaldinnet2onlinedecoder_loop(
/* GstElement vmethod implementations */ /* GstElement vmethod implementations */
/*
 * Pad query handler (installed on the sink pad).
 *
 * GST_QUERY_CAPS is answered with a single fixed caps structure: mono S16LE
 * raw audio at the sample rate taken from the feature configuration
 * (mfcc_opts.frame_opts.samp_freq). If the feature info has not been built
 * yet it is created here, so that the rate is known before any model is
 * loaded. When the query carries filter caps, the answer is restricted to
 * the intersection with that filter, as the CAPS query contract requires.
 *
 * Every other query type is forwarded to gst_pad_query_default().
 *
 * Returns TRUE when the query was handled, otherwise whatever the default
 * handler returns.
 */
static gboolean
gst_kaldinnet2onlinedecoder_query (GstPad *pad, GstObject * parent, GstQuery * query) {
  Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(parent);
  gboolean ret;

  switch (GST_QUERY_TYPE (query)) {
    case GST_QUERY_CAPS: {
      // Lazily build the feature info: the sample rate we advertise comes
      // from the feature configuration, not from upstream.
      if (filter->feature_info == NULL) {
        filter->feature_info = new OnlineNnet2FeaturePipelineInfo(*(filter->feature_config));
        filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq;
      }

      GstCaps *new_caps = gst_caps_new_simple ("audio/x-raw",
          "format", G_TYPE_STRING, "S16LE",
          "rate", G_TYPE_INT, filter->sample_rate,
          "channels", G_TYPE_INT, 1, NULL);

      // Only caps matching the query's filter may be returned. The filter is
      // owned by the query, so it must not be unreffed here.
      GstCaps *query_filter = NULL;
      gst_query_parse_caps (query, &query_filter);
      if (query_filter != NULL) {
        GstCaps *restricted = gst_caps_intersect_full (query_filter, new_caps,
            GST_CAPS_INTERSECT_FIRST);
        gst_caps_unref (new_caps);
        new_caps = restricted;
      }

      GST_DEBUG_OBJECT (filter, "Setting caps query result: %" GST_PTR_FORMAT, new_caps);
      gst_query_set_caps_result (query, new_caps);
      // The query holds its own reference to the result caps.
      gst_caps_unref (new_caps);
      ret = TRUE;
      break;
    }
    default:
      ret = gst_pad_query_default (pad, parent, query);
      break;
  }
  return ret;
}
/* this function handles sink events */ /* this function handles sink events */
static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad, static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad,
GstObject * parent, GstObject * parent,
...@@ -647,9 +724,8 @@ static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad, ...@@ -647,9 +724,8 @@ static gboolean gst_kaldinnet2onlinedecoder_sink_event(GstPad * pad,
static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad,
GstObject * parent, GstObject * parent,
GstBuffer * buf) { GstBuffer * buf) {
Gstkaldinnet2onlinedecoder *filter; Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(parent);
filter = GST_KALDINNET2ONLINEDECODER(parent);
if (G_UNLIKELY(!filter->audio_source)) if (G_UNLIKELY(!filter->audio_source))
goto not_negotiated; goto not_negotiated;
...@@ -669,15 +745,21 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad, ...@@ -669,15 +745,21 @@ static GstFlowReturn gst_kaldinnet2onlinedecoder_chain(GstPad * pad,
} }
} }
static bool gst_kaldinnet2onlinedecoder_allocate(
static bool
gst_kaldinnet2onlinedecoder_allocate(
Gstkaldinnet2onlinedecoder * filter) { Gstkaldinnet2onlinedecoder * filter) {
if (!filter->feature_info) { if (!filter->decode_fst) {
GST_INFO_OBJECT(filter, "Loading Kaldi models and feature extractor"); GST_INFO_OBJECT(filter, "Loading Kaldi models and feature extractor");
filter->audio_source = new GstBufferSource(); filter->audio_source = new GstBufferSource();
filter->feature_info = new OnlineNnet2FeaturePipelineInfo( if (filter->feature_info == NULL) {
*(filter->feature_config)); filter->feature_info = new OnlineNnet2FeaturePipelineInfo(*(filter->feature_config));
filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq;
}
filter->sample_rate = (int) filter->feature_info->mfcc_opts.frame_opts.samp_freq;
filter->trans_model = new TransitionModel(); filter->trans_model = new TransitionModel();
filter->nnet = new nnet2::AmNnet(); filter->nnet = new nnet2::AmNnet();
...@@ -710,6 +792,7 @@ static bool gst_kaldinnet2onlinedecoder_deallocate( ...@@ -710,6 +792,7 @@ static bool gst_kaldinnet2onlinedecoder_deallocate(
static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state( static GstStateChangeReturn gst_kaldinnet2onlinedecoder_change_state(
GstElement *element, GstStateChange transition) { GstElement *element, GstStateChange transition) {
GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS; GstStateChangeReturn ret = GST_STATE_CHANGE_SUCCESS;
Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(element); Gstkaldinnet2onlinedecoder *filter = GST_KALDINNET2ONLINEDECODER(element);
......
...@@ -84,6 +84,8 @@ struct _Gstkaldinnet2onlinedecoder ...@@ -84,6 +84,8 @@ struct _Gstkaldinnet2onlinedecoder
GstPad *sinkpad, *srcpad; GstPad *sinkpad, *srcpad;
GstCaps *sink_caps;
gboolean silent; gboolean silent;
gboolean do_endpointing; gboolean do_endpointing;
GstBufferSource *audio_source; GstBufferSource *audio_source;
...@@ -102,6 +104,7 @@ struct _Gstkaldinnet2onlinedecoder ...@@ -102,6 +104,7 @@ struct _Gstkaldinnet2onlinedecoder
nnet2::AmNnet *nnet; nnet2::AmNnet *nnet;
fst::Fst<fst::StdArc> *decode_fst; fst::Fst<fst::StdArc> *decode_fst;
fst::SymbolTable *word_syms; fst::SymbolTable *word_syms;
int sample_rate;
}; };
struct _Gstkaldinnet2onlinedecoderClass struct _Gstkaldinnet2onlinedecoderClass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment