ruk·si

Keras
Mid-training Callbacks

Updated at 2018-06-22 14:37

You can define callbacks, functions that Keras runs at set points during training, e.g. for logging. Passing `verbose=0` to `model.fit(...)` disables the default progress logging.

import json

import keras
import numpy as np
from keras.layers import Dense, Dropout
from keras.losses import categorical_crossentropy
from keras.models import Sequential
from keras.optimizers import SGD

def _random_split(num_samples):
	"""Build one random dataset split.

	Returns a tuple ``(features, labels, one_hot)``: ``features`` has shape
	(num_samples, 20) with values uniform in [0, 1); ``labels`` has shape
	(num_samples, 1) with integers in 0..9; ``one_hot`` is ``labels``
	encoded over 10 classes.
	"""
	features = np.random.random((num_samples, 20))
	labels = np.random.randint(10, size=(num_samples, 1))
	one_hot = keras.utils.to_categorical(labels, num_classes=10)
	return features, labels, one_hot


# Train split is drawn before the test split, same RNG order as inline code.
x_train, labels_train, y_train = _random_split(1000)
x_test, labels_test, y_test = _random_split(100)

# Two 64-unit ReLU hidden layers, each followed by 50% dropout, feeding a
# 10-way softmax output. Layers are added in list order.
model = Sequential([
	Dense(64, activation='relu', input_dim=20),
	Dropout(0.5),
	Dense(64, activation='relu'),
	Dropout(0.5),
	Dense(10, activation='softmax'),
])

# NOTE(review): `lr` and `decay` are the Keras 2.2-era argument names; newer
# Keras releases renamed `lr` to `learning_rate` and dropped `decay` —
# confirm against the installed version before upgrading.
sgd = SGD(
	lr=0.02,
	decay=1e-6,
	momentum=0.9,
	nesterov=True,
)
model.compile(optimizer=sgd, loss=categorical_crossentropy)

# Normally you wouldn't override all of these hooks, only the ones you use,
# which in this case is just on_epoch_end.
class JsonLogging(keras.callbacks.Callback):
	"""Keras callback that prints one JSON line per epoch with the training loss.

	Output format: ``{"epoch": <int>, "loss": <float or null>}``.
	All other hooks are no-op overrides kept only to show which hooks exist.
	"""

	def set_params(self, params):
		super().set_params(params)

	def set_model(self, model):
		super().set_model(model)

	def on_epoch_begin(self, epoch, logs=None):
		pass

	def on_epoch_end(self, epoch, logs=None):
		"""Print the epoch index and its ``loss`` entry as a JSON object."""
		# logs defaults to None; treat that the same as an empty dict rather
		# than crashing on the subscript.
		logs = logs or {}
		loss = logs.get("loss")
		# The loss may arrive as a NumPy scalar; json.dumps rejects types such
		# as np.float32, so convert to a built-in float before serializing.
		if loss is not None:
			loss = float(loss)
		print(json.dumps({'epoch': epoch, 'loss': loss}))

	def on_batch_begin(self, batch, logs=None):
		pass

	def on_batch_end(self, batch, logs=None):
		pass

	def on_train_begin(self, logs=None):
		pass

	def on_train_end(self, logs=None):
		pass

# verbose=0 turns off Keras' built-in progress logging, so the per-epoch
# output comes from the JsonLogging callback instead.
model.fit(
	x_train, y_train,
	batch_size=128,
	epochs=20,
	callbacks=[JsonLogging()],
	verbose=0,
)
# => {"epoch": 0, "loss": 2.3880539054870606} etc.