Skip to content
Snippets Groups Projects
Commit 3be2d7ac authored by tanel's avatar tanel
Browse files

Implemented optional rescoring with a (large) 'constant ARPA' LM

parent 0b91c350
No related branches found
No related tags found
No related merge requests found
...@@ -23,7 +23,7 @@ EXTRA_LDLIBS = -pthread -lgstbase-1.0 -lgstcontroller-1.0 -lgstreamer-1.0 -lgobj ...@@ -23,7 +23,7 @@ EXTRA_LDLIBS = -pthread -lgstbase-1.0 -lgstcontroller-1.0 -lgstreamer-1.0 -lgobj
#Kaldi shared libraries required by the GStreamer plugin #Kaldi shared libraries required by the GStreamer plugin
EXTRA_LDLIBS += -lkaldi-online2 -lkaldi-lat -lkaldi-decoder -lkaldi-feat -lkaldi-transform \ EXTRA_LDLIBS += -lkaldi-online2 -lkaldi-lat -lkaldi-decoder -lkaldi-feat -lkaldi-transform \
-lkaldi-gmm -lkaldi-hmm \ -lkaldi-gmm -lkaldi-hmm \
-lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lkaldi-thread -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base -lkaldi-thread -lkaldi-lm -lfstscript
OBJFILES = gstkaldinnet2onlinedecoder.o simple-options-gst.o gst-audio-source.o kaldimarshal.o OBJFILES = gstkaldinnet2onlinedecoder.o simple-options-gst.o gst-audio-source.o kaldimarshal.o
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "lat/confidence.h" #include "lat/confidence.h"
#include <fst/script/project.h>
namespace kaldi { namespace kaldi {
...@@ -73,6 +74,8 @@ enum { ...@@ -73,6 +74,8 @@ enum {
PROP_INVERSE_SCALE, PROP_INVERSE_SCALE,
PROP_LMWT_SCALE, PROP_LMWT_SCALE,
PROP_CHUNK_LENGTH_IN_SECS, PROP_CHUNK_LENGTH_IN_SECS,
PROP_LM_FST,
PROP_BIG_LM_CONST_ARPA,
PROP_LAST PROP_LAST
}; };
...@@ -221,6 +224,23 @@ static void gst_kaldinnet2onlinedecoder_class_init( ...@@ -221,6 +224,23 @@ static void gst_kaldinnet2onlinedecoder_class_init(
DEFAULT_CHUNK_LENGTH_IN_SECS, DEFAULT_CHUNK_LENGTH_IN_SECS,
(GParamFlags) G_PARAM_READWRITE)); (GParamFlags) G_PARAM_READWRITE));
// Register the "lm-fst" string property: file name of the old LM as an FST.
// It is read/write and defaults to the empty string.
g_object_class_install_property(
    gobject_class,
    PROP_LM_FST,
    g_param_spec_string(
        "lm-fst",
        "Language model FST (G.fst), only needed when rescoring with the constant ARPA LM",
        "Old LM as FST (G.fst)", "", (GParamFlags) G_PARAM_READWRITE));
g_object_class_install_property(
gobject_class,
PROP_BIG_LM_CONST_ARPA,
g_param_spec_string(
"big-lm-const-arpa",
"Big language model in constant ARPA format (typically G.carpa), to be used for rescoring final lattices. Also requires 'lm-fst' property",
"Big language model in constant ARPA format (typically G.carpa), to be used for rescoring final lattices. Also requires 'lm-fst' property",
"", (GParamFlags) G_PARAM_READWRITE));
gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new( gst_kaldinnet2onlinedecoder_signals[PARTIAL_RESULT_SIGNAL] = g_signal_new(
"partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST, "partial-result", G_TYPE_FROM_CLASS(klass), G_SIGNAL_RUN_LAST,
G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result), G_STRUCT_OFFSET(Gstkaldinnet2onlinedecoderClass, partial_result),
...@@ -300,6 +320,9 @@ static void gst_kaldinnet2onlinedecoder_init( ...@@ -300,6 +320,9 @@ static void gst_kaldinnet2onlinedecoder_init(
filter->inverse_scale = FALSE; filter->inverse_scale = FALSE;
filter->chunk_length_in_secs = DEFAULT_CHUNK_LENGTH_IN_SECS; filter->chunk_length_in_secs = DEFAULT_CHUNK_LENGTH_IN_SECS;
filter->lm_fst_name = g_strdup("");
filter->big_lm_const_arpa_name = g_strdup("");
// init properties from various Kaldi Opts // init properties from various Kaldi Opts
GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter); GstElementClass * klass = GST_ELEMENT_GET_CLASS(filter);
...@@ -414,6 +437,14 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object, ...@@ -414,6 +437,14 @@ static void gst_kaldinnet2onlinedecoder_set_property(GObject * object,
case PROP_CHUNK_LENGTH_IN_SECS: case PROP_CHUNK_LENGTH_IN_SECS:
filter->chunk_length_in_secs = g_value_get_float(value); filter->chunk_length_in_secs = g_value_get_float(value);
break; break;
case PROP_LM_FST:
g_free(filter->lm_fst_name);
filter->lm_fst_name = g_value_dup_string(value);
break;
case PROP_BIG_LM_CONST_ARPA:
g_free(filter->big_lm_const_arpa_name);
filter->big_lm_const_arpa_name = g_value_dup_string(value);
break;
case PROP_ADAPTATION_STATE: case PROP_ADAPTATION_STATE:
{ {
if (G_VALUE_HOLDS_STRING(value)) { if (G_VALUE_HOLDS_STRING(value)) {
...@@ -520,6 +551,12 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object, ...@@ -520,6 +551,12 @@ static void gst_kaldinnet2onlinedecoder_get_property(GObject * object,
case PROP_CHUNK_LENGTH_IN_SECS: case PROP_CHUNK_LENGTH_IN_SECS:
g_value_set_float(value, filter->chunk_length_in_secs); g_value_set_float(value, filter->chunk_length_in_secs);
break; break;
case PROP_LM_FST:
g_value_set_string(value, filter->lm_fst_name);
break;
case PROP_BIG_LM_CONST_ARPA:
g_value_set_string(value, filter->big_lm_const_arpa_name);
break;
case PROP_ADAPTATION_STATE: case PROP_ADAPTATION_STATE:
string_stream.clear(); string_stream.clear();
if (filter->adaptation_state) { if (filter->adaptation_state) {
...@@ -655,6 +692,62 @@ static void gst_kaldinnet2onlinedecoder_partial_result( ...@@ -655,6 +692,62 @@ static void gst_kaldinnet2onlinedecoder_partial_result(
} }
} }
// Rescores a lattice in two stages and writes the outcome to 'result_lat'.
//
// Stage 1 composes 'clat' with filter->lm_fst while the graph scores are
// scaled by -1, determinizes, and scales by -1 again (presumably to take the
// old LM scores out of the lattice).  Stage 2 composes the outcome with
// filter->big_lm_const_arpa and determinizes once more.
//
// Returns false when either stage yields an empty lattice, true otherwise.
// The caller is expected to have checked that filter->lm_fst and
// filter->big_lm_const_arpa are non-NULL.
static bool gst_kaldinnet2onlinedecoder_rescore_big_lm(
    Gstkaldinnet2onlinedecoder * filter, CompactLattice &clat, CompactLattice &result_lat) {

  // ---- Stage 1: compose with the LM FST at scale -1 ----
  Lattice input_lat;
  ConvertLattice(clat, &input_lat);

  // The weights are scaled by the inverse of the LM scale before composition
  // and by the LM scale itself afterwards; doing it in this order lets the
  // determinization pick the best path through the LM whatever the sign of
  // the scale is.
  fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &input_lat);
  ArcSort(&input_lat, fst::OLabelCompare<LatticeArc>());

  // A plain Compose(lat, lm_fst, &out) would also work without the compose
  // cache, but TableCompose is faster: constant rather than logarithmic in
  // the vocabulary size.
  Lattice old_lm_composed;
  TableCompose(input_lat, *(filter->lm_fst), &old_lm_composed,
               filter->lm_compose_cache);
  Invert(&old_lm_composed);  // put the word labels on the input side

  CompactLattice old_lm_determinized;
  DeterminizeLattice(old_lm_composed, &old_lm_determinized);
  fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &old_lm_determinized);

  if (old_lm_determinized.Start() == fst::kNoStateId) {
    GST_INFO_OBJECT(filter, "Empty lattice (incompatible LM?)");
    return false;
  }

  // ---- Stage 2: compose with the constant ARPA LM at scale 1 ----
  fst::ScaleLattice(fst::GraphLatticeScale(1.0), &old_lm_determinized);
  ArcSort(&old_lm_determinized, fst::OLabelCompare<CompactLatticeArc>());

  // The FST wrapper around the ConstArpaLm is built anew for every lattice
  // so that memory usage does not keep growing over time.
  ConstArpaLmDeterministicFst big_lm_fst(*(filter->big_lm_const_arpa));

  CompactLattice big_lm_composed_clat;
  ComposeCompactLatticeDeterministic(old_lm_determinized,
                                     &big_lm_fst, &big_lm_composed_clat);

  // Determinize the composed lattice straight into the caller's output.
  Lattice big_lm_composed;
  ConvertLattice(big_lm_composed_clat, &big_lm_composed);
  Invert(&big_lm_composed);
  DeterminizeLattice(big_lm_composed, &result_lat);
  fst::ScaleLattice(fst::GraphLatticeScale(1.0), &result_lat);

  if (result_lat.Start() == fst::kNoStateId) {
    GST_INFO_OBJECT(filter, "Empty lattice (incompatible LM?)");
    return false;
  }
  return true;
}
static void gst_kaldinnet2onlinedecoder_loop( static void gst_kaldinnet2onlinedecoder_loop(
Gstkaldinnet2onlinedecoder * filter) { Gstkaldinnet2onlinedecoder * filter) {
...@@ -709,6 +802,14 @@ static void gst_kaldinnet2onlinedecoder_loop( ...@@ -709,6 +802,14 @@ static void gst_kaldinnet2onlinedecoder_loop(
bool end_of_utterance = true; bool end_of_utterance = true;
decoder.GetLattice(end_of_utterance, &clat); decoder.GetLattice(end_of_utterance, &clat);
GST_DEBUG_OBJECT(filter, "Lattice done"); GST_DEBUG_OBJECT(filter, "Lattice done");
if ((filter->lm_fst != NULL) && (filter->big_lm_const_arpa != NULL)) {
GST_DEBUG_OBJECT(filter, "Rescoring lattice with a big LM");
CompactLattice rescored_lat;
if (gst_kaldinnet2onlinedecoder_rescore_big_lm(filter, clat, rescored_lat)) {
clat = rescored_lat;
}
}
double tot_like = 0.0; double tot_like = 0.0;
int64 num_frames = 0; int64 num_frames = 0;
...@@ -880,6 +981,56 @@ gst_kaldinnet2onlinedecoder_allocate( ...@@ -880,6 +981,56 @@ gst_kaldinnet2onlinedecoder_allocate(
filter->adaptation_state = new OnlineIvectorExtractorAdaptationState( filter->adaptation_state = new OnlineIvectorExtractorAdaptationState(
filter->feature_info->ivector_extractor_info); filter->feature_info->ivector_extractor_info);
// Optional: load the models needed to rescore final lattices with a big LM.
// This is only attempted when both the old LM FST and the big constant ARPA
// LM have been configured.
if ((strlen(filter->lm_fst_name) > 0) &&
    (strlen(filter->big_lm_const_arpa_name) > 0)) {
  GST_DEBUG_OBJECT(filter, "Loading models for LM rescoring with a big LM");
  fst::script::MutableFstClass *fst =
      fst::script::MutableFstClass::Read(filter->lm_fst_name, true);

  // Read() may return NULL when the file cannot be loaded, and GetFst<>() may
  // return NULL when the stored arc type is not StdArc.  Check both instead
  // of dereferencing blindly.
  const fst::Fst<fst::StdArc> *tmp_fst = NULL;
  if (fst != NULL) {
    fst::script::Project(fst, fst::PROJECT_OUTPUT);
    tmp_fst = fst->GetFst<fst::StdArc>();
  }

  if (tmp_fst == NULL) {
    // Leave lm_fst / big_lm_const_arpa untouched so that the rescoring step,
    // which checks those pointers, is simply skipped.
    GST_ERROR_OBJECT(filter,
        "Failed to load LM FST from %s; rescoring with a big LM is disabled",
        filter->lm_fst_name);
    delete fst;
  } else {
    fst::VectorFst<fst::StdArc> *std_lm_fst = new fst::VectorFst<fst::StdArc>(*tmp_fst);

    if (std_lm_fst->Properties(fst::kILabelSorted, true) == 0) {
      // Make sure LM is sorted on ilabel.
      fst::ILabelCompare<fst::StdArc> ilabel_comp;
      fst::ArcSort(std_lm_fst, ilabel_comp);
    }

    // mapped_fst is the LM fst interpreted using the LatticeWeight semiring,
    // with all the cost on the first member of the pair (since it's a graph
    // weight).
    int32 num_states_cache = 50000;
    fst::CacheOptions cache_opts(true, num_states_cache);
    fst::StdToLatticeMapper<BaseFloat> mapper;
    filter->lm_fst = new fst::MapFst<fst::StdArc, LatticeArc,
        fst::StdToLatticeMapper<BaseFloat> >(*std_lm_fst, mapper, cache_opts);
    delete std_lm_fst;
    // tmp_fst points into the object held by 'fst'; it goes away together
    // with 'fst' and must not be deleted on its own.
    delete fst;

    // The next few lines are a kind of optimization and
    // can be ignored if you just want to understand what is going on.
    // Change the options for TableCompose to match the input
    // (because it's the arcs of the LM FST we want to do lookup
    // on).
    fst::TableComposeOptions compose_opts(fst::TableMatcherOptions(),
                                          true, fst::SEQUENCE_FILTER,
                                          fst::MATCH_INPUT);

    // The following is an optimization for the TableCompose
    // composition: it stores certain tables that enable fast
    // lookup of arcs during composition.
    filter->lm_compose_cache = new fst::TableComposeCache<fst::Fst<LatticeArc> >(compose_opts);

    GST_DEBUG_OBJECT(filter, "Loading big LM in constant ARPA format");
    filter->big_lm_const_arpa = new ConstArpaLm();
    ReadKaldiObject(filter->big_lm_const_arpa_name, filter->big_lm_const_arpa);
  }
}
} }
return true; return true;
} }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "online2/online-endpoint.h" #include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h" #include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h" #include "lat/lattice-functions.h"
#include "lm/const-arpa-lm.h"
namespace kaldi { namespace kaldi {
...@@ -82,6 +83,13 @@ struct _Gstkaldinnet2onlinedecoder { ...@@ -82,6 +83,13 @@ struct _Gstkaldinnet2onlinedecoder {
gboolean decoding; gboolean decoding;
float chunk_length_in_secs; float chunk_length_in_secs;
OnlineIvectorExtractorAdaptationState *adaptation_state; OnlineIvectorExtractorAdaptationState *adaptation_state;
// The following are needed for optional LM rescoring with a "big" LM
gchar* lm_fst_name;
gchar* big_lm_const_arpa_name;
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_fst;
fst::TableComposeCache<fst::Fst<LatticeArc> > *lm_compose_cache;
ConstArpaLm *big_lm_const_arpa;
}; };
struct _Gstkaldinnet2onlinedecoderClass { struct _Gstkaldinnet2onlinedecoderClass {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment