diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc
index 70f9aa38093c4e6a12480e1d3fc253bc755f1d91..a0154aa3f67998903470f8cf2ed2590e69fef348 100644
--- a/src/gstkaldinnet2onlinedecoder.cc
+++ b/src/gstkaldinnet2onlinedecoder.cc
@@ -306,7 +306,11 @@ static void gst_kaldinnet2onlinedecoder_init(
 
   filter->endpoint_config = new OnlineEndpointConfig();
   filter->feature_config = new OnlineNnet2FeaturePipelineConfig();
+#ifdef THREADED_DECODER
+  filter->nnet2_decoding_config = new OnlineNnet2DecodingThreadedConfig();
+#else
   filter->nnet2_decoding_config = new OnlineNnet2DecodingConfig();
+#endif
 
   filter->endpoint_config->Register(filter->simple_options);
   filter->feature_config->Register(filter->simple_options);
@@ -616,8 +620,13 @@ static void gst_kaldinnet2onlinedecoder_final_result(
 
   CompactLattice best_path_clat;
   if (filter->inverse_scale) {
+#ifdef THREADED_DECODER
+    BaseFloat inv_acoustic_scale = 1.0 / filter->
+        nnet2_decoding_config->acoustic_scale;
+#else
     BaseFloat inv_acoustic_scale = 1.0 / filter->nnet2_decoding_config->
         decodable_opts.acoustic_scale;
+#endif
     fst::ScaleLattice(fst::AcousticLatticeScale(inv_acoustic_scale), &clat);
   }
 
@@ -666,6 +675,7 @@ static void gst_kaldinnet2onlinedecoder_final_result(
   }
 }
 
+#ifndef THREADED_DECODER // can't have partial results in threaded decoder
 static void gst_kaldinnet2onlinedecoder_partial_result(
     Gstkaldinnet2onlinedecoder * filter, const Lattice lat) {
   LatticeWeight weight;
@@ -691,6 +701,7 @@ static void gst_kaldinnet2onlinedecoder_partial_result(
                       sentence.str().c_str());
   }
 }
+#endif
 
 static bool gst_kaldinnet2onlinedecoder_rescore_big_lm(
     Gstkaldinnet2onlinedecoder * filter, CompactLattice &clat, CompactLattice &result_lat) {
@@ -752,55 +763,89 @@ static void gst_kaldinnet2onlinedecoder_loop(
     Gstkaldinnet2onlinedecoder * filter) {
   GST_DEBUG_OBJECT(filter, "Starting decoding loop..");
+#ifndef THREADED_DECODER // used for partial result, unavailable with threading
   BaseFloat traceback_period_secs = 1.0;
+#endif
   int32 chunk_length = int32(filter->sample_rate * filter->chunk_length_in_secs);
 
   bool more_data = true;
   while (more_data) {
+
+#ifdef THREADED_DECODER
+    SingleUtteranceNnet2DecoderThreaded decoder(*(filter->nnet2_decoding_config),
+                                                *(filter->trans_model), *(filter->nnet),
+                                                *(filter->decode_fst),
+                                                *(filter->feature_info),
+                                                *(filter->adaptation_state));
+#else
     OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info));
     feature_pipeline.SetAdaptationState(*(filter->adaptation_state));
-
     SingleUtteranceNnet2Decoder decoder(*(filter->nnet2_decoding_config),
                                         *(filter->trans_model), *(filter->nnet),
                                         *(filter->decode_fst),
                                         &feature_pipeline);
+#endif
 
     Vector<BaseFloat> wave_part = Vector<BaseFloat>(chunk_length);
     GST_DEBUG_OBJECT(filter, "Reading audio in %d sample chunks...",
                      wave_part.Dim());
+#ifndef THREADED_DECODER // used for partial result, unavailable with threading
    BaseFloat last_traceback = 0.0;
+#endif
    BaseFloat num_seconds_decoded = 0.0;
    while (true) {
      more_data = filter->audio_source->Read(&wave_part);
+#ifdef THREADED_DECODER
+      decoder.AcceptWaveform(filter->sample_rate, wave_part);
+#else
      feature_pipeline.AcceptWaveform(filter->sample_rate, wave_part);
+#endif
      if (!more_data) {
+#ifdef THREADED_DECODER
+        decoder.InputFinished();
+#else
        feature_pipeline.InputFinished();
+#endif
      }
+#ifndef THREADED_DECODER // the threaded decoder just starts when it gets data
      decoder.AdvanceDecoding();
+#endif
      if (!more_data) {
        break;
      }
      if (filter->do_endpointing &&
          decoder.EndpointDetected(*(filter->endpoint_config))) {
+#ifdef THREADED_DECODER // must stop it (the non-threaded is already idle)
+        decoder.TerminateDecoding();
+#endif
        GST_DEBUG_OBJECT(filter, "Endpoint detected!");
        break;
      }
      num_seconds_decoded += filter->chunk_length_in_secs;
+#ifndef THREADED_DECODER // I think it's only relevant to the single-thread version
      if (num_seconds_decoded - last_traceback > traceback_period_secs) {
        Lattice lat;
        decoder.GetBestPath(false, &lat);
        gst_kaldinnet2onlinedecoder_partial_result(filter, lat);
        last_traceback += traceback_period_secs;
      }
+#endif
    }
    if (num_seconds_decoded > 0.1) {
+#ifdef THREADED_DECODER // joining the decoding thread
+      decoder.Wait();
+#endif
      GST_DEBUG_OBJECT(filter, "Getting lattice..");
      decoder.FinalizeDecoding();
      CompactLattice clat;
      bool end_of_utterance = true;
+#ifdef THREADED_DECODER
+      decoder.GetLattice(end_of_utterance, &clat, NULL);
+#else
      decoder.GetLattice(end_of_utterance, &clat);
+#endif
      GST_DEBUG_OBJECT(filter, "Lattice done");
      if ((filter->lm_fst != NULL) && (filter->big_lm_const_arpa != NULL)) {
        GST_DEBUG_OBJECT(filter, "Rescoring lattice with a big LM");
@@ -818,7 +863,11 @@ static void gst_kaldinnet2onlinedecoder_loop(
                                                &tot_like, &num_words);
      if (num_words > 0) {
        // Only update adaptation state if the utterance was not empty
+#ifdef THREADED_DECODER
+        decoder.GetAdaptationState(filter->adaptation_state);
+#else
        feature_pipeline.GetAdaptationState(filter->adaptation_state);
+#endif
      }
    } else {
      GST_DEBUG_OBJECT(filter, "Less than 0.1 seconds decoded, discarding");
diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h
index f14eea3ad8126888ac520f998b7e6c9387402e90..dfb28f41772a58d84c6515c05d53d1d09341a0b2 100644
--- a/src/gstkaldinnet2onlinedecoder.h
+++ b/src/gstkaldinnet2onlinedecoder.h
@@ -25,7 +25,12 @@
 #include "./simple-options-gst.h"
 #include "./gst-audio-source.h"
 
+#ifdef THREADED_DECODER
+#include "online2/online-nnet2-decoding-threaded.h"
+#else
 #include "online2/online-nnet2-decoding.h"
+#endif
+
 #include "online2/onlinebin-util.h"
 #include "online2/online-timing.h"
 #include "online2/online-endpoint.h"
@@ -72,7 +77,11 @@ struct _Gstkaldinnet2onlinedecoder {
   SimpleOptionsGst *simple_options;
   OnlineEndpointConfig *endpoint_config;
   OnlineNnet2FeaturePipelineConfig *feature_config;
+#ifdef THREADED_DECODER
+  OnlineNnet2DecodingThreadedConfig *nnet2_decoding_config;
+#else
   OnlineNnet2DecodingConfig *nnet2_decoding_config;
+#endif
 
   OnlineNnet2FeaturePipelineInfo *feature_info;
   TransitionModel *trans_model;