Commit ff37cb00 authored by Amit Beka's avatar Amit Beka
Browse files

add support for threaded nnet2 decoding



recent kaldi version add a new decoder class of threaded processing,
with a different API than the previous (SingleUtteranceNnet2DecoderThreaded
instead of SingleUtteranceNnet2Decoder).

this commit adds compile-time support to the threaded decoder, by
ifdef'ing all the relevant API changes. future commit should make it
prettier, but for now, in order to avoid everybody updating kaldi source
code, this seems better.

compile the threaded decoder with make CPPFLAGS=-DTHREADED_DECODER.
Signed-off-by: default avatarAmit Beka <amit.beka@gmail.com>
parent c938fcf0
......@@ -306,7 +306,11 @@ static void gst_kaldinnet2onlinedecoder_init(
filter->endpoint_config = new OnlineEndpointConfig();
filter->feature_config = new OnlineNnet2FeaturePipelineConfig();
#ifdef THREADED_DECODER
filter->nnet2_decoding_config = new OnlineNnet2DecodingThreadedConfig();
#else
filter->nnet2_decoding_config = new OnlineNnet2DecodingConfig();
#endif
filter->endpoint_config->Register(filter->simple_options);
filter->feature_config->Register(filter->simple_options);
......@@ -616,8 +620,13 @@ static void gst_kaldinnet2onlinedecoder_final_result(
CompactLattice best_path_clat;
if (filter->inverse_scale) {
#ifdef THREADED_DECODER
BaseFloat inv_acoustic_scale = 1.0 / filter->
nnet2_decoding_config->acoustic_scale;
#else
BaseFloat inv_acoustic_scale = 1.0 / filter->nnet2_decoding_config->
decodable_opts.acoustic_scale;
#endif
fst::ScaleLattice(fst::AcousticLatticeScale(inv_acoustic_scale), &clat);
}
......@@ -666,6 +675,7 @@ static void gst_kaldinnet2onlinedecoder_final_result(
}
}
#ifndef THREADED_DECODER // can't have partial results in threaded decoder
static void gst_kaldinnet2onlinedecoder_partial_result(
Gstkaldinnet2onlinedecoder * filter, const Lattice lat) {
LatticeWeight weight;
......@@ -691,6 +701,7 @@ static void gst_kaldinnet2onlinedecoder_partial_result(
sentence.str().c_str());
}
}
#endif
static bool gst_kaldinnet2onlinedecoder_rescore_big_lm(
Gstkaldinnet2onlinedecoder * filter, CompactLattice &clat, CompactLattice &result_lat) {
......@@ -752,55 +763,89 @@ static void gst_kaldinnet2onlinedecoder_loop(
Gstkaldinnet2onlinedecoder * filter) {
GST_DEBUG_OBJECT(filter, "Starting decoding loop..");
#ifndef THREADED_DECODER // used for partial result, unavailable with threading
BaseFloat traceback_period_secs = 1.0;
#endif
int32 chunk_length = int32(filter->sample_rate * filter->chunk_length_in_secs);
bool more_data = true;
while (more_data) {
#ifdef THREADED_DECODER
SingleUtteranceNnet2DecoderThreaded decoder(*(filter->nnet2_decoding_config),
*(filter->trans_model), *(filter->nnet),
*(filter->decode_fst),
*(filter->feature_info),
*(filter->adaptation_state));
#else
OnlineNnet2FeaturePipeline feature_pipeline(*(filter->feature_info));
feature_pipeline.SetAdaptationState(*(filter->adaptation_state));
SingleUtteranceNnet2Decoder decoder(*(filter->nnet2_decoding_config),
*(filter->trans_model), *(filter->nnet),
*(filter->decode_fst),
&feature_pipeline);
#endif
Vector<BaseFloat> wave_part = Vector<BaseFloat>(chunk_length);
GST_DEBUG_OBJECT(filter, "Reading audio in %d sample chunks...",
wave_part.Dim());
#ifndef THREADED_DECODER // used for partial result, unavailable with threading
BaseFloat last_traceback = 0.0;
#endif
BaseFloat num_seconds_decoded = 0.0;
while (true) {
more_data = filter->audio_source->Read(&wave_part);
#ifdef THREADED_DECODER
decoder.AcceptWaveform(filter->sample_rate, wave_part);
#else
feature_pipeline.AcceptWaveform(filter->sample_rate, wave_part);
#endif
if (!more_data) {
#ifdef THREADED_DECODER
decoder.InputFinished();
#else
feature_pipeline.InputFinished();
#endif
}
#ifndef THREADED_DECODER // the threaded decoder just starts when it gets data
decoder.AdvanceDecoding();
#endif
if (!more_data) {
break;
}
if (filter->do_endpointing
&& decoder.EndpointDetected(*(filter->endpoint_config))) {
#ifdef THREADED_DECODER // must stop it (the non-threaded is already idle)
decoder.TerminateDecoding();
#endif
GST_DEBUG_OBJECT(filter, "Endpoint detected!");
break;
}
num_seconds_decoded += filter->chunk_length_in_secs;
#ifndef THREADED_DECODER // I think it's only relevant to the single-thread version
if (num_seconds_decoded - last_traceback > traceback_period_secs) {
Lattice lat;
decoder.GetBestPath(false, &lat);
gst_kaldinnet2onlinedecoder_partial_result(filter, lat);
last_traceback += traceback_period_secs;
}
#endif
}
if (num_seconds_decoded > 0.1) {
#ifdef THREADED_DECODER // joining the decoding thread
decoder.Wait();
#endif
GST_DEBUG_OBJECT(filter, "Getting lattice..");
decoder.FinalizeDecoding();
CompactLattice clat;
bool end_of_utterance = true;
#ifdef THREADED_DECODER
decoder.GetLattice(end_of_utterance, &clat, NULL);
#else
decoder.GetLattice(end_of_utterance, &clat);
#endif
GST_DEBUG_OBJECT(filter, "Lattice done");
if ((filter->lm_fst != NULL) && (filter->big_lm_const_arpa != NULL)) {
GST_DEBUG_OBJECT(filter, "Rescoring lattice with a big LM");
......@@ -818,7 +863,11 @@ static void gst_kaldinnet2onlinedecoder_loop(
&tot_like, &num_words);
if (num_words > 0) {
// Only update adaptation state if the utterance was not empty
#ifdef THREADED_DECODER
decoder.GetAdaptationState(filter->adaptation_state);
#else
feature_pipeline.GetAdaptationState(filter->adaptation_state);
#endif
}
} else {
GST_DEBUG_OBJECT(filter, "Less than 0.1 seconds decoded, discarding");
......
......@@ -25,7 +25,12 @@
#include "./simple-options-gst.h"
#include "./gst-audio-source.h"
#ifdef THREADED_DECODER
#include "online2/online-nnet2-decoding-threaded.h"
#else
#include "online2/online-nnet2-decoding.h"
#endif
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
......@@ -72,7 +77,11 @@ struct _Gstkaldinnet2onlinedecoder {
SimpleOptionsGst *simple_options;
OnlineEndpointConfig *endpoint_config;
OnlineNnet2FeaturePipelineConfig *feature_config;
#ifdef THREADED_DECODER
OnlineNnet2DecodingThreadedConfig *nnet2_decoding_config;
#else
OnlineNnet2DecodingConfig *nnet2_decoding_config;
#endif
OnlineNnet2FeaturePipelineInfo *feature_info;
TransitionModel *trans_model;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment