/*
 * dnn_train.c
 * Train a FANN feed-forward network from a FANN training file and save the
 * trained model.
 */
#include <stdio.h>
#include <stdlib.h>
#include "fann.h"
#include "context.h"
#include "cf_file.h"
/*
 * Training progress callback with the signature FANN expects for
 * fann_set_callback(): prints the current epoch number, the network's
 * current MSE and the requested target error.
 *
 * Returns 0 so that training continues (FANN stops training when a
 * callback returns -1).
 *
 * NOTE(review): this callback is not registered with fann_set_callback()
 * anywhere in this file.
 */
int FANN_API test_callback(struct fann *ann, struct fann_train_data *train,
                           unsigned int max_epochs, unsigned int epochs_between_reports,
                           float desired_error, unsigned int epochs)
{
  /* Required by the callback signature but not needed for the report. */
  (void) train;
  (void) max_epochs;
  (void) epochs_between_reports;

  /* epochs is unsigned, so it must be printed with %u (not %d). */
  printf("Epochs %8u. MSE: %.5f. Desired-MSE: %.5f\n", epochs, fann_get_MSE(ann), desired_error);
  return 0;
}
/*
 * Print the usage text for dnn_train.
 *
 * The general option groups come first, via the context_*_help_message
 * helpers. They are followed by an "INPUT" section describing the FANN
 * training-file option and an "OUTPUT" section describing the model-file
 * option. The section headers themselves are written to stderr.
 */
void dnn_train_help_message(context *ctx)
{
  /* General options. */
  context_general_help_message(ctx);
  context_iterations_help_message(ctx);
  context_sent_nb_help_message(ctx);
  context_hidden_neurons_nb_help_message(ctx);

  /* Input file. */
  fputs("INPUT\n", stderr);
  context_fann_help_message(ctx);

  /* Output file. */
  fputs("OUTPUT\n", stderr);
  context_dnn_model_help_message(ctx);
}
/*
 * Validate the parsed command-line options.
 *
 * Both the FANN training file and the output model file must be set, and
 * help must not have been requested. If any of these conditions fails, the
 * usage text is printed and the process terminates with status 1.
 */
void dnn_train_check_options(context *ctx)
{
  int options_ok = ctx->fann_filename && ctx->dnn_model_filename && !ctx->help;

  if(options_ok){
    return;
  }

  dnn_train_help_message(ctx);
  exit(1);
}
/*
 * Train a FANN network and save it.
 *
 * Reads the command-line options into a context. It then builds a 3-layer
 * network (input, one hidden layer, output). The input/output sizes come
 * from the training file; the hidden size comes from ctx->hidden_neurons_nb.
 * The network is trained with RPROP for at most ctx->iteration_nb epochs and
 * written to ctx->dnn_model_filename.
 *
 * Returns EXIT_SUCCESS when the model was saved. Returns EXIT_FAILURE when
 * the training data cannot be read, the network cannot be created, or the
 * model cannot be written. Invalid options are handled by
 * dnn_train_check_options(), which exits the process itself.
 */
int main(int argc, char *argv[])
{
  const unsigned int num_layers = 3;
  /* The stop function is FANN_STOPFUNC_BIT, so desired_error counts failing
     output bits. A value of 0 trains until no output misses the bit-fail
     limit, or until max_epochs is reached. */
  const float desired_error = 0.0f;
  const unsigned int epochs_between_reports = 10;
  unsigned int max_epochs;
  unsigned int num_neurons_hidden;
  unsigned int input_size;
  unsigned int output_size;
  struct fann_train_data *data;
  struct fann *ann;
  context *ctx;
  int status = EXIT_FAILURE;

  ctx = context_read_options(argc, argv);
  dnn_train_check_options(ctx);

  max_epochs = ctx->iteration_nb;
  num_neurons_hidden = ctx->hidden_neurons_nb;

  data = fann_read_train_from_file(ctx->fann_filename);
  if(data == NULL){
    /* FANN returns NULL on failure; without this check the size queries
       below would dereference a NULL pointer. */
    fprintf(stderr, "cannot read training data from %s\n", ctx->fann_filename);
    return EXIT_FAILURE;
  }
  input_size = fann_num_input_train_data(data);
  output_size = fann_num_output_train_data(data);

  printf("Creating network.\n");
  ann = fann_create_standard(num_layers, input_size, num_neurons_hidden, output_size);
  if(ann == NULL){
    fprintf(stderr, "cannot create network\n");
    goto out_data;
  }

  fann_set_activation_steepness_hidden(ann, 1);
  fann_set_activation_steepness_output(ann, 1);
  fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC);
  /* The output layer keeps FANN's default activation function. */
  fann_set_train_stop_function(ann, FANN_STOPFUNC_BIT);
  fann_set_bit_fail_limit(ann, 0.01f);
  fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);

  printf("Training network.\n");
  /* Initialise weights from the range of the training data (Widrow-Nguyen). */
  fann_init_weights(ann, data);
  fann_train_on_data(ann, data, max_epochs, epochs_between_reports, desired_error);

  printf("Saving network.\n");
  /* fann_save returns 0 on success; a failed save must not look like success. */
  if(fann_save(ann, ctx->dnn_model_filename) == 0){
    status = EXIT_SUCCESS;
  } else {
    fprintf(stderr, "cannot save network to %s\n", ctx->dnn_model_filename);
  }

  printf("Cleaning up.\n");
  fann_destroy(ann);
out_data:
  fann_destroy_train(data);
  return status;
}
/*
printf("Testing network. %f\n", fann_test_data(ann, data));
for(i = 0; i < fann_length_train_data(data); i++){
calc_out = fann_run(ann, data->input[i]);
argmax = 0;
max = calc_out[argmax];
ref = 0;
for(j=0; j < output_size; j++){
if(data->output[i][j] == 1) ref = j;
if(calc_out[j] > max){
max = calc_out[j];
argmax = j;
}
}
if(argmax == ref) correct++;
}
printf("precision = %f\n", (float) correct /fann_length_train_data(data));*/
/*printf("XOR test (%f,%f) -> %f, should be %f, difference=%f\n",
data->input[i][0], data->input[i][1], calc_out[0], data->output[i][0],
fann_abs(calc_out[0] - data->output[i][0]));*/