# keras_fc_cnn_mnist.py -- LeNet-style CNN (conv + fully-connected) on MNIST, Keras.
# Author: Luc Giffon
import numpy as np
from keras import optimizers
from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Dropout
from keras.callbacks import LearningRateScheduler, TensorBoard
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
import skluc.main.data.mldatasets as dataset
batch_size = 128
epochs = 200
iterations = 391
weight_decay = 0.0001
log_filepath = './lenet_dp_da_wd'
def build_model():
model = Sequential()
model.add(Conv2D(32, (5, 5), padding='valid', activation='relu', kernel_initializer='he_normal',
kernel_regularizer=l2(weight_decay), input_shape=(28, 28, 1), name="conv_1"))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name="conv_pool_1"))
model.add(Conv2D(64, (5, 5), padding='valid', activation='relu', kernel_initializer='he_normal',
kernel_regularizer=l2(weight_decay), name="conv_2"))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name="conv_pool_2"))
model.add(Flatten(name='conv_flatten'))
model.add(Dense(1024, activation='relu', kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay)))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay)))
sgd = optimizers.SGD(lr=.001, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
return model
def scheduler(epoch):
if epoch <= 60:
return 0.05
if epoch <= 120:
return 0.01
if epoch <= 160:
return 0.002
return 0.0004
if __name__ == '__main__':
# load data
validation_size = 10000
seed = 0
data = dataset.MnistDataset(validation_size=validation_size, seed=seed)
data.load()
data.normalize()
data.to_one_hot()
data.to_image()
data.data_astype(np.float32)
data.labels_astype(np.float32)
(x_train, y_train), (x_test, y_test) = data.train, data.test
# build network
model = build_model()
print(model.summary())
# set callback
tb_cb = TensorBoard(log_dir=log_filepath, histogram_freq=0)
change_lr = LearningRateScheduler(scheduler)
cbks = [change_lr, tb_cb]
# using real-time data augmentation
print('Using real-time data augmentation.')
datagen = ImageDataGenerator(horizontal_flip=True,
width_shift_range=0.125, height_shift_range=0.125, fill_mode='constant', cval=0.)
datagen.fit(x_train)
# start traing
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
steps_per_epoch=iterations,
epochs=epochs,
callbacks=cbks,
validation_data=(x_test, y_test))
# save model
model.save('lenet.h5')
print("Final evaluation on test set: {}".format(model.evaluate(x_test, y_test)))