Skip to content
Snippets Groups Projects
Select Git revision
  • 248bfbf7e30c8d2e2619fdd331173c224794e98a
  • main default protected
2 results

test_cli.py

Blame
  • dnn_train.c 3.20 KiB
    #include <stdio.h>
    #include "fann.h"
    #include "context.h"
    #include "cf_file.h"
    
    /*
     * Progress reporter with the shape of a FANN training callback.
     *
     * Prints one line to stdout containing the epoch count, the network's current
     * MSE as returned by fann_get_MSE(ann), and the desired error.
     *
     * ann            network whose MSE is reported
     * train          unused
     * max_epochs     unused
     * epochs_between_reports  unused
     * desired_error  value echoed in the "Desired-MSE" column
     * epochs         epoch count echoed in the "Epochs" column
     *
     * Always returns 0.
     *
     * NOTE(review): nothing in the visible code registers this function with the
     * network, so it may be dead code -- confirm before relying on its output.
     */
    int FANN_API test_callback(struct fann *ann, struct fann_train_data *train,
        unsigned int max_epochs, unsigned int epochs_between_reports,
        float desired_error, unsigned int epochs)
    {
        /* Required by the signature, deliberately ignored here. */
        (void) train;
        (void) max_epochs;
        (void) epochs_between_reports;

        /* epochs is unsigned, so the conversion must be %u rather than %d. */
        printf("Epochs     %8u. MSE: %.5f. Desired-MSE: %.5f\n", epochs, fann_get_MSE(ann), desired_error);
        return 0;
    }
    
    /*
     * Emits the usage text for this program.
     *
     * Calls the context_* help helpers in a fixed order: general, iterations,
     * sentence number and hidden-neuron options first, then the FANN file helper
     * under an "INPUT" heading and the DNN model helper under an "OUTPUT" heading.
     * Both headings are written to stderr.
     *
     * ctx  option context forwarded unchanged to every helper
     */
    void dnn_train_help_message(context *ctx)
    {
      /* Options that are not tied to the input or output section. */
      context_general_help_message(ctx);
      context_iterations_help_message(ctx);
      context_sent_nb_help_message(ctx);
      context_hidden_neurons_nb_help_message(ctx);

      /* Input section. */
      fputs("INPUT\n", stderr);
      context_fann_help_message(ctx);

      /* Output section. */
      fputs("OUTPUT\n", stderr);
      context_dnn_model_help_message(ctx);
    }
    
    
    /*
     * Validates the parsed options for the training program.
     *
     * Returns normally only when both ctx->fann_filename and
     * ctx->dnn_model_filename are set and ctx->help is not set. In every other
     * case the help message is printed and the process exits with status 1.
     *
     * ctx  parsed option context
     */
    void dnn_train_check_options(context *ctx)
    {
      /* Everything required is present and help was not requested. */
      if(ctx->fann_filename && ctx->dnn_model_filename && !ctx->help)
        return;

      dnn_train_help_message(ctx);
      exit(1);
    }
    
    
    /*
     * Program entry point: trains a three-layer FANN network and saves it.
     *
     * Steps:
     *   1. Parse and validate the command-line options (dnn_train_check_options
     *      exits the process itself when they are unusable).
     *   2. Load the training set from ctx->fann_filename.
     *   3. Build a network whose input/output sizes come from the training set and
     *      whose hidden layer has ctx->hidden_neurons_nb neurons.
     *   4. Train for at most ctx->iteration_nb epochs.
     *   5. Save the network to ctx->dnn_model_filename.
     *
     * Returns 0 on success, 1 when the training data cannot be loaded, the network
     * cannot be created, or the trained network cannot be saved.
     *
     * NOTE(review): ctx is never released; no routine for freeing a context is
     * visible in this file -- confirm whether one exists.
     */
    int main(int argc, char *argv[])
    {
      const unsigned int num_layers = 3;
      const float desired_error = (const float) 0;
      const unsigned int epochs_between_reports = 10;
      unsigned int input_size;
      unsigned int output_size;
      unsigned int num_neurons_hidden;
      unsigned int max_epochs;
      struct fann *ann;
      struct fann_train_data *data;
      context *ctx;
      int status = 0;

      ctx = context_read_options(argc, argv);
      dnn_train_check_options(ctx);

      max_epochs = ctx->iteration_nb;
      num_neurons_hidden = ctx->hidden_neurons_nb;

      /* The sizes below are read from data, so a failed load must stop here. */
      data = fann_read_train_from_file(ctx->fann_filename);
      if(!data){
        fprintf(stderr, "Cannot read training data from %s\n", ctx->fann_filename);
        return 1;
      }
      input_size = fann_num_input_train_data(data);
      output_size = fann_num_output_train_data(data);

      printf("Creating network.\n");
      ann = fann_create_standard(num_layers, input_size, num_neurons_hidden, output_size);
      if(!ann){
        fprintf(stderr, "Cannot create network.\n");
        fann_destroy_train(data);
        return 1;
      }
      fann_set_activation_steepness_hidden(ann, 1);
      fann_set_activation_steepness_output(ann, 1);
      fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
      /* fann_set_activation_function_output(ann, FANN_SIGMOID_SYMMETRIC); */
      /* NOTE(review): with the bit-fail stop function, desired_error (0) is
         presumably compared against the bit-fail count rather than the MSE --
         confirm against the FANN documentation. */
      fann_set_train_stop_function(ann, FANN_STOPFUNC_BIT);
      fann_set_bit_fail_limit(ann, 0.01f);
      fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);

      printf("Training network.\n");

      fann_init_weights(ann, data);
      fann_train_on_data(ann, data, max_epochs, epochs_between_reports, desired_error);

      printf("Saving network.\n");
      /* A failed save must not be reported as success to the caller. */
      if(fann_save(ann, ctx->dnn_model_filename) != 0){
        fprintf(stderr, "Cannot save network to %s\n", ctx->dnn_model_filename);
        status = 1;
      }

      printf("Cleaning up.\n");
      fann_destroy_train(data);
      fann_destroy(ann);

      return status;
    }
    
    
    /*
      printf("Testing network. %f\n", fann_test_data(ann, data));
      
      for(i = 0; i < fann_length_train_data(data); i++){
        calc_out = fann_run(ann, data->input[i]);
        argmax = 0;
        max = calc_out[argmax];
        ref = 0;
        for(j=0; j < output_size; j++){
          if(data->output[i][j] == 1) ref = j;
          if(calc_out[j] > max){
    	max = calc_out[j];
    	argmax = j;
          }
        }
        if(argmax == ref) correct++;
        
    
      }
      printf("precision = %f\n", (float) correct /fann_length_train_data(data));*/
    
        /*printf("XOR test (%f,%f) -> %f, should be %f, difference=%f\n",
          data->input[i][0], data->input[i][1], calc_out[0], data->output[i][0],
          fann_abs(calc_out[0] - data->output[i][0]));*/