#include<stdio.h>
#include<stdlib.h>
#include<string.h>

#include "trie.h"
#include "util.h"

/*
 * Allocate a trie state and fill in its fields.
 *
 * transitions: head of the transition list stored in the new state
 *              (may be NULL for a state with no outgoing transition).
 * is_accept:   value stored in the state's is_accept flag.
 *
 * The fail field starts at 0. Returns the new state; the memory comes
 * from memalloc() and is released with trie_state_free().
 */
trie_state *trie_state_new(trie_trans *transitions, int is_accept)
{
  trie_state *fresh = memalloc(sizeof *fresh);

  fresh->fail = 0;
  fresh->is_accept = is_accept;
  fresh->transitions = transitions;

  return fresh;
}

/*
 * Release a state: its transition list is handed to
 * trie_trans_free_rec(), then the state itself is freed.
 * A NULL state is silently ignored.
 */
void trie_state_free(trie_state *state)
{
  if(state == NULL)
    return;

  trie_trans_free_rec(state->transitions);
  free(state);
}

/*
 * Create a trie that contains only its initial state.
 *
 * The state array starts empty (no storage, size 0, no states); the
 * call to trie_add_state() then creates state 0, the initial state.
 * Returns the new trie, to be released with trie_free().
 */
trie *trie_new(void)
{
  trie *fresh = memalloc(sizeof *fresh);

  fresh->states_nb = 0;
  fresh->size = 0;
  fresh->states = NULL;

  /* state 0 is the initial state */
  trie_add_state(fresh);

  return fresh;
}

/*
 * Release a trie: every state that was created, the state array, and
 * the trie structure itself. A NULL trie is ignored.
 *
 * Only the first states_nb slots of the array hold valid pointers.
 * The slots between states_nb and size are spare capacity obtained
 * from realloc() and are never initialised, so they must not be
 * passed to trie_state_free().
 */
void trie_free(trie *t)
{
  int i;
  if(t){
    for(i=0; i < t->states_nb; i++)
      trie_state_free(t->states[i]);
    free(t->states);
    free(t);
  }
}

/*
 * Allocate one transition cell.
 *
 * destination: index of the state the transition leads to.
 * symbol:      symbol labelling the transition.
 * next:        cell that follows the new one in the list (may be NULL),
 *              so passing the current list head prepends the new cell.
 *
 * Returns the new cell.
 */
trie_trans *trie_trans_new(int destination, int symbol, trie_trans *next)
{
  trie_trans *cell = memalloc(sizeof *cell);

  cell->next = next;
  cell->symbol = symbol;
  cell->destination = destination;

  return cell;
}

/*
 * Free every cell of a transition list, starting at trans.
 * A NULL list is accepted and nothing is done.
 *
 * The list is walked iteratively: the successor is saved before the
 * current cell is freed, so each cell is released exactly once and the
 * stack depth does not grow with the list length. (The "_rec" suffix
 * is kept so existing callers keep working.)
 */
void trie_trans_free_rec(trie_trans *trans)
{
  while(trans){
    trie_trans *following = trans->next;
    free(trans);
    trans = following;
  }
}

/*
 * Append a new, non-accepting state with no transition to the trie.
 *
 * When the state array is full, its capacity grows to 2 * (size + 1)
 * slots. The result of realloc() is checked before it replaces
 * t->states; if the array cannot be grown, an error is printed on
 * stderr and the program exits, because the function has no way to
 * report failure through its return value.
 *
 * Returns the index of the newly created state.
 */
int trie_add_state(trie *t)
{
  if(t->states_nb == t->size){
    trie_state **grown;

    t->size = 2 * (t->size + 1);
    grown = realloc(t->states, t->size * sizeof(trie_state *));
    if(grown == NULL){
      fprintf(stderr, "trie_add_state: out of memory\n");
      exit(EXIT_FAILURE);
    }
    t->states = grown;
  }
  t->states[t->states_nb] = trie_state_new(NULL, 0);
  t->states_nb++;
  return t->states_nb - 1;
}

/*
 * Add the transition origin --symbol--> destination to the trie.
 *
 * States are created as needed until both origin and destination are
 * valid state indices. The new transition is then placed at the head
 * of origin's transition list. No check is made for an existing
 * transition with the same symbol.
 */
void trie_add_trans(trie *t, int origin, int symbol, int destination)
{
  /* highest state index that must exist before the list is updated */
  int highest = (origin > destination) ? origin : destination;

  while(t->states_nb <= highest)
    trie_add_state(t);

  t->states[origin]->transitions =
    trie_trans_new(destination, symbol, t->states[origin]->transitions);
}

/*
 * Insert a word (a sequence of `length` symbols) into the trie.
 *
 * Phase 1 follows, from state 0, the transitions that already spell a
 * prefix of the word. Phase 2 creates one new state and one new
 * transition for each remaining symbol. The state reached at the end
 * is flagged as accepting; for length == 0 that is state 0.
 */
void trie_add_word(trie *t, int *word, int length)
{
  int pos = 0;
  int state = 0;
  trie_trans *edge;

  /* phase 1: consume the longest prefix already present */
  while(pos < length){
    edge = t->states[state]->transitions;
    while(edge != NULL && edge->symbol != word[pos])
      edge = edge->next;

    if(edge == NULL)
      break; /* no transition on word[pos]: the prefix ends here */

    state = edge->destination;
    pos++;
  }

  /* phase 2: extend the trie with the remaining symbols */
  for(; pos < length; pos++){
    int created = trie_add_state(t);
    trie_add_trans(t, state, word[pos], created);
    state = created;
  }

  t->states[state]->is_accept = 1;
}

/*
 * Write a textual dump of the trie to f.
 *
 * For each state, in index order, the output is:
 *   - "state <index>", followed by " ACCEPT" when the state is accepting;
 *   - "FAIL = <fail value>";
 *   - one "<origin> <symbol> <destination>" line per transition;
 *   - an empty line.
 */
void trie_print(FILE *f, trie *t)
{
  int idx = 0;
  trie_trans *edge;

  while(idx < t->states_nb){
    fprintf(f, "state %d", idx);
    if(!t->states[idx]->is_accept)
      fprintf(f, "\n");
    else
      fprintf(f, " ACCEPT\n");

    fprintf(f, "FAIL = %d\n", t->states[idx]->fail);

    edge = t->states[idx]->transitions;
    while(edge != NULL){
      fprintf(f, "%d %d %d\n", idx, edge->symbol, edge->destination);
      edge = edge->next;
    }

    fprintf(f, "\n");
    idx++;
  }
}

/*
 * Test whether a word (a sequence of `length` symbols) is in the trie.
 *
 * Starting from state 0, each symbol is matched against the transition
 * list of the current state. Returns 0 as soon as a symbol has no
 * matching transition; otherwise returns the is_accept flag of the
 * state reached after the last symbol.
 */
int trie_lookup(trie *t, int *word, int length)
{
  int state = 0;
  int pos;

  for(pos = 0; pos < length; pos++){
    trie_trans *edge = t->states[state]->transitions;

    while(edge != NULL && edge->symbol != word[pos])
      edge = edge->next;

    if(edge == NULL)
      return 0;

    state = edge->destination;
  }

  return t->states[state]->is_accept;
}


/*
 * Build a trie from a text file holding one word per line, each word
 * being a list of integers separated by single spaces.
 *
 * filename: path of the file, opened for reading through myfopen().
 *
 * Each line is read with fgets(), split on ' ' with strtok(), and every
 * token is converted with atoi() and appended to the word, which is
 * then inserted with trie_add_word().
 *
 * Limits:
 *   - the line buffer is as large as the size given to fgets(), so a
 *     read can never write past the end of the buffer;
 *   - a word holds at most WORD_CAPACITY symbols; extra tokens on a
 *     line are ignored and a warning is printed on stderr.
 *
 * NOTE(review): myfopen() is defined elsewhere; this code assumes it
 * never returns NULL (f is used unchecked) -- confirm against util.
 *
 * Returns the new trie, to be released with trie_free().
 */
trie *trie_build_from_collection(char *filename)
{
  enum { LINE_CAPACITY = 10000, WORD_CAPACITY = 100 };

  trie *t = trie_new();
  FILE *f = myfopen(filename, "r");
  char buffer[LINE_CAPACITY];
  int word[WORD_CAPACITY];
  int length;
  char *token;

  while(fgets(buffer, sizeof buffer, f)){
    length = 0;
    token = strtok(buffer, " ");
    while(token){
      if(length == WORD_CAPACITY){
        /* word[] is full: drop the rest of the line instead of overflowing */
        fprintf(stderr,
                "trie_build_from_collection: more than %d symbols on a line, extra symbols ignored\n",
                (int) WORD_CAPACITY);
        break;
      }
      word[length++] = atoi(token);
      token = strtok(NULL, " ");
    }
    trie_add_word(t, word, length);
  }
  fclose(f);
  return t;
}

/*
 * Return the state reached from `origin` by reading `symbol`.
 *
 * The transition list of origin is scanned for the first transition
 * labelled with symbol. When there is none, 0 is returned, which is
 * also the index of the initial state, so callers cannot tell a
 * missing transition from a transition that leads to state 0.
 */
int trie_destination_state(trie *t, int origin, int symbol)
{
  trie_trans *edge = t->states[origin]->transitions;

  while(edge != NULL){
    if(edge->symbol == symbol)
      return edge->destination;
    edge = edge->next;
  }

  return 0;
}