/*  librocio_slu.cc  */
/*  SLU for Rocio XML  */
/*  FRED 0215  */

#include <string>
#include <vector>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

#include "librocio_slu.h"

extern "C" {

#include "lia_liblex.h"

/*................................................................*/

#define TailleLigne     80000

#define True    1
#define False   0

/*
 * Fatal-error reporter: prints "ERREUR : <ch1> <ch2>" on stderr and
 * terminates the process. Never returns.
 *
 *   ch1, ch2  message fragments; a NULL fragment is printed as an empty string
 *             (passing NULL to %s is undefined behaviour).
 *
 * The process exits with EXIT_FAILURE so that the calling shell or parent
 * process can tell the run aborted.
 */
void ERREUR(const char *ch1, const char *ch2)
{
    fprintf(stderr,"ERREUR : %s %s\n",ch1 ? ch1 : "",ch2 ? ch2 : "");
    exit(EXIT_FAILURE);
}

/*
 * Fatal-error reporter with an integer detail: prints "ERREUR : <ch1> <i>"
 * on stderr and terminates the process. Never returns.
 *
 *   ch1  message text; NULL is printed as an empty string
 *        (passing NULL to %s is undefined behaviour).
 *   i    integer appended to the message (e.g. an offending code).
 *
 * The process exits with EXIT_FAILURE so that the calling shell or parent
 * process can tell the run aborted.
 */
void ERREURd(const char *ch1, int i)
{
    fprintf(stderr,"ERREUR : %s %d\n",ch1 ? ch1 : "",i);
    exit(EXIT_FAILURE);
}

/*................................................................*/

#define MAX_FIELD	60000

static const char *CHglouton="<joker>";
static const char *CHepsilon="<epsilon>";

#define IEPSILON	0
#define IGLOU		1
#define PENALEPSILON	50
#define PENALGLOU	100
#define WINLENGTH       30

/*
 * One word slot shared between build_fst_words and run_process.
 *
 * Values of `select` as consumed by run_process when it rebuilds the
 * word list:
 *   0  word is left out,
 *   1  word is emitted unchanged,
 *   2  word is replaced by the "<joker>" token,
 *   3  word is emitted with a "**" prefix.
 */
typedef struct type_outword
{
    int index;      /* word code in the word lexicon */
    char select;    /* status flag, see above */
} type_outword;

/*................................................................*/


fst::StdVectorFst *build_fst_words(slu_t* slu, char **words, int num_words, int lexidword, char *prevword, type_outword *t_outword, int *nbword)
{
    int i, nb,numstate,code,uncertain,deca;
    fst::StdVectorFst *input;
    char *pt;
    input = new fst::StdVectorFst;
    input->AddState();
    input->SetStart(0);
    numstate=nb=0;

    /* add the previous words */
    if (prevword)
    {
        for(pt=strtok(prevword," \t\n");pt;pt=strtok(NULL," \t\n"))
        {
            if (!strncmp(pt,"**",2)) { uncertain=True; deca=2; t_outword[nb].select=3; } else { uncertain=False; deca=0; t_outword[nb].select=1; }
            if (word2code(lexidword,pt+deca,&code))
            {
                t_outword[nb++].index=code; if (nb==MAX_FIELD) ERREUR("cste MAX_FIELD too small","");
                input->AddState();
                input->AddArc(numstate,fst::StdArc(code,code,0,numstate+1));
                if (uncertain)
                {
                    input->AddArc(numstate,fst::StdArc(code,IGLOU,PENALGLOU,numstate+1));
                    input->AddArc(numstate,fst::StdArc(code,IEPSILON,PENALEPSILON,numstate+1));
                }
                numstate++;
            }
        }
    } else if(slu->words->size() > 0) { // add support for memorized words
        for(size_t i = 0; i < slu->words->size(); i++) {
            char* pt = (*slu->words)[i];
            if (!strncmp(pt,"**",2)) { uncertain=True; deca=2; t_outword[nb].select=3; } else { uncertain=False; deca=0; t_outword[nb].select=1; }
            if (word2code(lexidword,pt+deca,&code))
            {
                t_outword[nb++].index=code; if (nb==MAX_FIELD) ERREUR("cste MAX_FIELD too small","");
                input->AddState();
                input->AddArc(numstate,fst::StdArc(code,code,0,numstate+1));
                if (uncertain)
                {
                    input->AddArc(numstate,fst::StdArc(code,IGLOU,PENALGLOU,numstate+1));
                    input->AddArc(numstate,fst::StdArc(code,IEPSILON,PENALEPSILON,numstate+1));
                }
                numstate++;
            }
        }
    }
    /* now the new words */
    for(i = 0; i < num_words; i++) {
        if (word2code(lexidword,words[i],&code))
        {
            t_outword[nb].select=3;
            t_outword[nb++].index=code; if (nb==MAX_FIELD) ERREUR("cste MAX_FIELD too small","");
            input->AddState();
            input->AddArc(numstate,fst::StdArc(code,IEPSILON,PENALEPSILON,numstate+1));
            input->AddArc(numstate,fst::StdArc(code,IGLOU,PENALGLOU,numstate+1));
            input->AddArc(numstate,fst::StdArc(code,code,0,numstate+1));
            numstate++;
        }
        //else fprintf(stderr,"Warning: [%s] is unknown, discared\n",words[i]);
    }
    input->SetFinal(numstate,0);
    *nbword=nb;
    return input;
}

/*
 * Empties both string lists held by the handle (actions, then words),
 * releasing every stored string with free().
 */
void reset_slu(slu_t* slu) {
    std::vector<char*> &actions = *slu->actions;
    for (std::vector<char*>::iterator it = actions.begin(); it != actions.end(); ++it) {
        free(*it);
    }
    actions.clear();

    std::vector<char*> &words = *slu->words;
    for (std::vector<char*>::iterator it = words.begin(); it != words.end(); ++it) {
        free(*it);
    }
    words.clear();
}

char* run_process(slu_t* slu, char** words, int num_words, int prevn, char *prevword)
{
    fst::StdVectorFst *input,result1,result2,result3;
    char *ch;
    static int *tocc,i,j,nb,nbac,nbword;
    type_outword *t_outword;
    t_outword=(type_outword *)malloc(sizeof(type_outword)*MAX_FIELD);
    tocc=(int*)malloc(sizeof(int)*(max_code_lexicon(slu->lexidaction)+1));
    for(i=0;i<max_code_lexicon(slu->lexidaction);i++) tocc[i]=0;

    input=build_fst_words(slu, words, num_words, slu->lexidword,prevword,t_outword,&nbword);
    //fprintf(stderr, "size of input = %d\n", input->NumStates());
    fst::ArcSort(input, fst::StdOLabelCompare());
    fst::Compose(*input, *(slu->fstClean), &result1);
    //fprintf(stderr, "size of result1 = %d\n", result1.NumStates());
    fst::ArcSort(&result1, fst::StdOLabelCompare());
    //fprintf(stderr, "size of model = %d\n", slu->fstModel->NumStates());
    fst::Compose(result1,*(slu->fstModel),&result2);
    //fprintf(stderr, "size of result2 = %d\n", result2.NumStates());
    fst::ShortestPath(result2,&result3,1,false);
    fst::TopSort(&result3);
    delete input;

    // reset actions
    for(size_t i = 0; i < slu->actions->size(); i++) {
        free((*slu->actions)[i]);
    }
    slu->actions->clear();

    // process nbest
    fst::StdVectorFst::StateId start = result3.Start();
    if ((int)start>=0)
    {
        for(fst::ArcIterator<fst::StdVectorFst> aiter(result3, start);!aiter.Done(); aiter.Next())
        {
            const fst::StdArc arc = aiter.Value(); // this arc is an epsilon arc leading to the next path
            nbword=0;
            if (arc.ilabel>0)
            {
                t_outword[nbword].index=arc.ilabel;
                if ((arc.olabel==0)&&(arc.weight!=0)) t_outword[nbword].select=0; else
                    if (arc.weight==0) t_outword[nbword].select=1; else t_outword[nbword].select=2;
                nbword++;
            }
            if (arc.olabel>=2)
            {
                nb=nbac=1;
                if (!code2word(slu->lexidaction,arc.olabel,&ch)) ERREURd("unknown action code:",arc.olabel);
                tocc[arc.olabel]++;
                if (nbac>prevn) {
                    slu->actions->push_back(strdup(ch));
                    //printf("%s[%d]=>%d\n",ch,tocc[arc.olabel],nbac);
                }
            }
            int64 state = arc.nextstate;
            while(result3.Final(state) == fst::StdArc::Weight::Zero())
            {
                const fst::StdArc nextArc = fst::ArcIterator<fst::StdVectorFst>(result3, state).Value();
                nb++;
                if (nextArc.ilabel>0)
                {
                    t_outword[nbword].index=nextArc.ilabel;
                    if ((nextArc.olabel==0)&&(nextArc.weight!=0)) t_outword[nbword].select=0; else
                        if (nextArc.weight==0) t_outword[nbword].select=1; else t_outword[nbword].select=2;
                    nbword++; if (nbword==MAX_FIELD) ERREUR("cste MAX_SIZE_MESG too small","");
                }
                if (nextArc.olabel>=2)
                {
                    nbac=nb;
                    if (!code2word(slu->lexidaction,nextArc.olabel,&ch)) ERREURd("unknown action code:",nextArc.olabel);
                    tocc[nextArc.olabel]++;
                    if (nbac>prevn) {
                        slu->actions->push_back(strdup(ch));
                        //printf("%s[%d]=>%d\n",ch,tocc[nextArc.olabel],nbac);
                    }
                }
                state = nextArc.nextstate;
            }
        }
    }


    //printf("STRING:");
    for(size_t i = 0; i < slu->words->size(); i++) free((*slu->words)[i]);
    slu->words->clear();

    for(j=nbword-1;(j>0)&&(t_outword[j].select==0);j--) t_outword[j].select=3;
    for(i=0;i<=j;i++) if (t_outword[i].select!=0)
    {
        if (!code2word(slu->lexidword,t_outword[i].index,&ch)) ERREURd("unknown word code:",t_outword[i].index);
        if (t_outword[i].select==2) {
            //printf(" %s",CHglouton); 
            slu->words->push_back(strdup(CHglouton));
        } else if (t_outword[i].select==3) {
            //printf(" **%s",ch); 
            char buffer[strlen(ch) + 3];
            buffer[0] = buffer[1] = '*';
            strcpy(buffer + 2, ch);
            slu->words->push_back(strdup(buffer));
        } else {
            //printf(" %s",ch);
            slu->words->push_back(strdup(ch));
        }
    }
    /* now we limit to a window of WINLENGTH unmatched words */
    if (nbword-WINLENGTH>j) j=nbword-WINLENGTH; else j++;
    for(;j<nbword;j++)
    {
        if (!code2word(slu->lexidword,t_outword[j].index,&ch)) ERREURd("unknown word code:",t_outword[j].index);
        char buffer[strlen(ch) + 3];
        buffer[0] = buffer[1] = '*';
        strcpy(buffer + 2, ch);
        slu->words->push_back(strdup(buffer));
        //printf(" **%s",ch);
    }
    free(tocc); free(t_outword);

    // make output string
    std::string output;
    for(size_t i = 0; i < slu->words->size(); i++) {
        if(i > 0) output += " ";
        output += (*slu->words)[i];
    }
    return strdup(output.c_str());
}

/*................................................................*/

/*
 * Creates a handle from four files.
 *
 *   chfileword    word lexicon, loaded with load_lexicon
 *   chfileaction  action lexicon, loaded with load_lexicon
 *   chfilemodel   machine read into fstModel
 *   chfileclean   machine read into fstClean
 *
 * Returns a handle to be released with free_slu, or NULL when a file name is
 * NULL, when memory cannot be obtained, or when either machine cannot be
 * read (StdVectorFst::Read yields a null pointer in that case). On failure a
 * message is written to stderr and everything allocated here is released.
 * NOTE(review): lexicons already loaded are not unloaded on failure; no
 * unloading routine is visible from this file.
 */
slu_t* init_slu(char* chfileword, char* chfileaction, char* chfilemodel, char* chfileclean) {
    if (!chfileword || !chfileaction || !chfilemodel || !chfileclean) {
        fprintf(stderr,"ERREUR : init_slu called with a NULL file name\n");
        return NULL;
    }

    slu_t* slu = (slu_t*) malloc(sizeof(slu_t));
    if (!slu) {
        fprintf(stderr,"ERREUR : init_slu out of memory\n");
        return NULL;
    }

    slu->lexidword = load_lexicon(chfileword);
    slu->lexidaction = load_lexicon(chfileaction);
    slu->fstModel = fst::StdVectorFst::Read(chfilemodel);
    slu->fstClean = fst::StdVectorFst::Read(chfileclean);
    if (!slu->fstModel || !slu->fstClean) {
        fprintf(stderr,"ERREUR : init_slu cannot read %s\n",slu->fstModel ? chfileclean : chfilemodel);
        delete slu->fstModel;   // deleting a null pointer is a no-op
        delete slu->fstClean;
        free(slu);
        return NULL;
    }

    slu->actions = new std::vector<char*>();
    slu->words = new std::vector<char*>();
    return slu;
}

/*
 * Releases a handle created by init_slu: both machines, every stored action
 * and word string, the two lists, and the handle itself.
 * A NULL handle is accepted and ignored, like free(NULL).
 */
void free_slu(slu_t* slu) {
    if (!slu) return;
    delete slu->fstClean;
    delete slu->fstModel;
    if (slu->actions) {
        for(size_t i = 0; i < slu->actions->size(); i++) free((*slu->actions)[i]);
        delete slu->actions;
    }
    if (slu->words) {
        for(size_t i = 0; i < slu->words->size(); i++) free((*slu->words)[i]);
        delete slu->words;
    }
    free(slu);
}

/* Number of action strings currently stored in the handle. */
int num_actions(slu_t* slu) {
    const std::vector<char*> &actions = *slu->actions;
    return static_cast<int>(actions.size());
}

/*
 * Returns the stored action string at `index`, or NULL when `index` is past
 * the end of the list. The string stays owned by the handle.
 */
char* get_action(slu_t* slu, size_t index) {
    const std::vector<char*> &actions = *slu->actions;
    if (index >= actions.size()) return NULL;   // size_t cannot be negative
    return actions[index];
}

/*
 * Joins all stored action strings with single spaces.
 * Returns a strdup'ed string (empty when there is no action) that the caller
 * must free().
 */
char* get_actions(slu_t* slu) {
    const std::vector<char*> &actions = *slu->actions;
    std::string joined;
    for (std::vector<char*>::const_iterator it = actions.begin(); it != actions.end(); ++it) {
        if (it != actions.begin()) joined.push_back(' ');
        joined.append(*it);
    }
    return strdup(joined.c_str());
}

/*
 * Public entry point: forwards every argument unchanged to run_process and
 * hands back its result.
 */
char* run_slu(slu_t* slu, char** words, int num_words, int prevn, char *prevword) {
    char* joined_words = run_process(slu, words, num_words, prevn, prevword);
    return joined_words;
}

}