diff --git a/src/Makefile b/src/Makefile index 96eff7bb6e388662ba7abf0c5a6e5ed314d010af..dd569e5908bcf664f148aeef38ba575163da716b 100644 --- a/src/Makefile +++ b/src/Makefile @@ -23,7 +23,7 @@ EXTRA_LDLIBS = -pthread -lgstbase-1.0 -lgstcontroller-1.0 -lgstreamer-1.0 -lgobj #Kaldi shared libraries required by the GStreamer plugin EXTRA_LDLIBS += -lkaldi-online2 -lkaldi-lat -lkaldi-decoder -lkaldi-feat -lkaldi-transform \ -lkaldi-gmm -lkaldi-hmm \ - -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lkaldi-thread + -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lkaldi-thread -lkaldi-lm -lfstscript OBJFILES = gstkaldinnet2onlinedecoder.o simple-options-gst.o gst-audio-source.o kaldimarshal.o diff --git a/src/gstkaldinnet2onlinedecoder.cc b/src/gstkaldinnet2onlinedecoder.cc index a1866b72864d25e413a3fc286fcf268cbf2dbfae..8031a839bc7937bdcfff1088f1532dd2c3ee21f2 100644 --- a/src/gstkaldinnet2onlinedecoder.cc +++ b/src/gstkaldinnet2onlinedecoder.cc @@ -49,6 +49,7 @@ #include "fstext/fstext-lib.h" #include "lat/confidence.h" +#include <fst/script/project.h> namespace kaldi { @@ -73,6 +74,8 @@ enum { PROP_INVERSE_SCALE, PROP_LMWT_SCALE, PROP_CHUNK_LENGTH_IN_SECS, + PROP_LM_FST, + PROP_BIG_LM_CONST_ARPA, PROP_LAST }; @@ -221,6 +224,23 @@ static void gst_kaldinnet2onlinedecoder_class_init( DEFAULT_CHUNK_LENGTH_IN_SECS, (GParamFlags) G_PARAM_READWRITE)); + g_object_class_install_property( + gobject_class, + PROP_LM_FST, + g_param_spec_string( + "lm-fst", + "Language language model FST (G.fst), only needed when rescoring with the constant ARPA LM", + "Old LM as FST (G.fst)", "", (GParamFlags) G_PARAM_READWRITE)); + + g_object_class_install_property( + gobject_class, + PROP_BIG_LM_CONST_ARPA, + g_param_spec_string( + "big-lm-const-arpa", + "Big language model in constant ARPA format (typically G.carpa), to be used for rescoring final lattices. Also requires 'lm-fst' property", + "Big language model in constant ARPA format (typically G.carpa), to be used for rescoring final lattices. Also requires 'lm-fst' property", + "", (GParamFlags) G_PARAM_READWRITE)); + gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new( "partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST, G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result), @@ -300,6 +320,9 @@ static void gst_kaldinnet2onlinedecoder_init( filter->inverse_scale = FALSE; filter->chunk_length_in_secs = DEFAULT_CHUNK_LENGTH_IN_SECS; + filter->lm_fst_name = g_strdup(""); + filter->big_lm_const_arpa_name = g_strdup(""); + // init properties from various Kaldi Opts GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter); @@ -414,6 +437,14 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, case PROP_CHUNK_LENGTH_IN_SECS: filter->chunk_length_in_secs = g_value_get_float(value); break; + case PROP_LM_FST: + g_free(filter->lm_fst_name); + filter->lm_fst_name = g_value_dup_string(value); + break; + case PROP_BIG_LM_CONST_ARPA: + g_free(filter->big_lm_const_arpa_name); + filter->big_lm_const_arpa_name = g_value_dup_string(value); + break; case PROP_ADAPTATION_STATE: { if (G_VALUE_HOLDS_STRING(value)) { @@ -520,6 +551,12 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object, case PROP_CHUNK_LENGTH_IN_SECS: g_value_set_float(value, filter->chunk_length_in_secs); break; + case PROP_LM_FST: + g_value_set_string(value, filter->lm_fst_name); + break; + case PROP_BIG_LM_CONST_ARPA: + g_value_set_string(value, filter->big_lm_const_arpa_name); + break; case PROP_ADAPTATION_STATE: string_stream.clear(); if (filter->adaptation_state) { @@ -655,6 +692,62 @@ static void gst_kaldinnet2onlinedecoder_partial_result( } } +static bool gst_kaldinnet2onlinedecoder_rescore_big_lm( + Gstkaldinnet2onlinedecoder * filter, CompactLattice &clat, CompactLattice &result_lat) { + + Lattice tmp_lattice; + ConvertLattice(clat, &tmp_lattice); + // Before composing with the LM FST, we scale the lattice weights + // by the inverse of "lm_scale". We'll later scale by "lm_scale". + // We do it this way so we can determinize and it will give the + // right effect (taking the "best path" through the LM) regardless + // of the sign of lm_scale. + fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &tmp_lattice); + ArcSort(&tmp_lattice, fst::OLabelCompare<LatticeArc>()); + + Lattice composed_lat; + // Could just do, more simply: Compose(lat, lm_fst, &composed_lat); + // and not have lm_compose_cache at all. + // The command below is faster, though; it's constant not + // logarithmic in vocab size. + + TableCompose(tmp_lattice, *(filter->lm_fst), &composed_lat, filter->lm_compose_cache); + + Invert(&composed_lat); // make it so word labels are on the input. + CompactLattice determinized_lat; + DeterminizeLattice(composed_lat, &determinized_lat); + fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &determinized_lat); + if (determinized_lat.Start() == fst::kNoStateId) { + GST_INFO_OBJECT(filter, "Empty lattice (incompatible LM?)"); + return false; + } else { + fst::ScaleLattice(fst::GraphLatticeScale(1.0), &determinized_lat); + ArcSort(&determinized_lat, fst::OLabelCompare<CompactLatticeArc>()); + + // Wraps the ConstArpaLm format language model into FST. We re-create it + // for each lattice to prevent memory usage increasing with time. + ConstArpaLmDeterministicFst const_arpa_fst(*(filter->big_lm_const_arpa)); + + // Composes lattice with language model. + CompactLattice composed_clat; + ComposeCompactLatticeDeterministic(determinized_lat, + &const_arpa_fst, &composed_clat); + + // Determinizes the composed lattice. + Lattice composed_lat; + ConvertLattice(composed_clat, &composed_lat); + Invert(&composed_lat); + DeterminizeLattice(composed_lat, &result_lat); + fst::ScaleLattice(fst::GraphLatticeScale(1.0), &result_lat); + if (result_lat.Start() == fst::kNoStateId) { + GST_INFO_OBJECT(filter, "Empty lattice (incompatible LM?)"); + return false; + } + } + return true; +} + + static void gst_kaldinnet2onlinedecoder_loop( Gstkaldinnet2onlinedecoder * filter) { @@ -709,6 +802,14 @@ static void gst_kaldinnet2onlinedecoder_loop( bool end_of_utterance = true; decoder.GetLattice(end_of_utterance, &clat); GST_DEBUG_OBJECT(filter, "Lattice done"); + if ((filter->lm_fst != NULL) && (filter->big_lm_const_arpa != NULL)) { + GST_DEBUG_OBJECT(filter, "Rescoring lattice with a big LM"); + CompactLattice rescored_lat; + if (gst_kaldinnet2onlinedecoder_rescore_big_lm(filter, clat, rescored_lat)) { + clat = rescored_lat; + } + } + double tot_like = 0.0; int64 num_frames = 0; @@ -880,6 +981,56 @@ gst_kaldinnet2onlinedecoder_allocate( filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( filter->feature_info->ivector_extractor_info); + if ((strlen(filter->lm_fst_name) > 0) && + (strlen(filter->big_lm_const_arpa_name) > 0)) { + GST_DEBUG_OBJECT(filter, "Loading models for LM rescoring with a big LM"); + fst::script::MutableFstClass *fst = + fst::script::MutableFstClass::Read(filter->lm_fst_name, true); + fst::script::Project(fst, fst::PROJECT_OUTPUT); + + const fst::Fst<fst::StdArc> *tmp_fst = fst->GetFst<fst::StdArc>(); + + fst::VectorFst<fst::StdArc> *std_lm_fst = new fst::VectorFst<fst::StdArc>(*tmp_fst); + + if (std_lm_fst->Properties(fst::kILabelSorted, true) == 0) { + // Make sure LM is sorted on ilabel. + fst::ILabelCompare<fst::StdArc> ilabel_comp; + fst::ArcSort(std_lm_fst, ilabel_comp); + } + + // mapped_fst is the LM fst interpreted using the LatticeWeight semiring, + // with all the cost on the first member of the pair (since it's a graph + // weight). + int32 num_states_cache = 50000; + fst::CacheOptions cache_opts(true, num_states_cache); + fst::StdToLatticeMapper<BaseFloat> mapper; + filter->lm_fst = new fst::MapFst<fst::StdArc, LatticeArc, + fst::StdToLatticeMapper<BaseFloat> >(*std_lm_fst, mapper, cache_opts); + delete std_lm_fst; + delete fst; + // FIXME: maybe? + //delete tmp_fst; + + // The next fifteen or so lines are a kind of optimization and + // can be ignored if you just want to understand what is going on. + // Change the options for TableCompose to match the input + // (because it's the arcs of the LM FST we want to do lookup + // on). + fst::TableComposeOptions compose_opts(fst::TableMatcherOptions(), + true, fst::SEQUENCE_FILTER, + fst::MATCH_INPUT); + + // The following is an optimization for the TableCompose + // composition: it stores certain tables that enable fast + // lookup of arcs during composition. + filter->lm_compose_cache = new fst::TableComposeCache<fst::Fst<LatticeArc> >(compose_opts); + + GST_DEBUG_OBJECT(filter, "Loading big LM in constant ARPA format"); + filter->big_lm_const_arpa = new ConstArpaLm(); + ReadKaldiObject(filter->big_lm_const_arpa_name, filter->big_lm_const_arpa); + + } + } return true; } diff --git a/src/gstkaldinnet2onlinedecoder.h b/src/gstkaldinnet2onlinedecoder.h index e1fc7e6aa33cfcdd2a86260d2ae8e3748541e377..f14eea3ad8126888ac520f998b7e6c9387402e90 100644 --- a/src/gstkaldinnet2onlinedecoder.h +++ b/src/gstkaldinnet2onlinedecoder.h @@ -31,6 +31,7 @@ #include "online2/online-endpoint.h" #include "fstext/fstext-lib.h" #include "lat/lattice-functions.h" +#include "lm/const-arpa-lm.h" namespace kaldi { @@ -82,6 +83,13 @@ struct _Gstkaldinnet2onlinedecoder { gboolean decoding; float chunk_length_in_secs; OnlineIvectorExtractorAdaptationState *adaptation_state; + + // The following are needed for optional LM rescoring with a "big" LM + gchar* lm_fst_name; + gchar* big_lm_const_arpa_name; + fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_fst; + fst::TableComposeCache<fst::Fst<LatticeArc> > *lm_compose_cache; + ConstArpaLm *big_lm_const_arpa; }; struct _Gstkaldinnet2onlinedecoderClass {