Keras - Mid-training Callbacks
Updated at 2018-06-22 14:37
You can define functions that run in the middle of training, e.g. for logging. Passing model.fit(..., verbose=0)
disables the default progress logging.
import json
import keras
import numpy as np
from keras.layers import Dense, Dropout
from keras.losses import categorical_crossentropy
from keras.models import Sequential
from keras.optimizers import SGD
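# dummy data: 20 random features per sample, 10 classes, labels one-hot encoded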
x_train = np.random.random((1000, 20))
labels_train = np.random.randint(10, size=(1000, 1))
y_train = keras.utils.to_categorical(labels_train, num_classes=10)
x_test = np.random.random((100, 20))
labels_test = np.random.randint(10, size=(100, 1))
y_test = keras.utils.to_categorical(labels_test, num_classes=10)
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=20))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=categorical_crossentropy, optimizer=sgd)
# normally you wouldn't override all of these hooks, just the ones you need,
# which here would be on_epoch_end only (a minimal version is sketched below)
class JsonLogging(keras.callbacks.Callback):
    def set_params(self, params):
        super().set_params(params)
    def set_model(self, model):
        super().set_model(model)
    def on_epoch_begin(self, epoch, logs=None):
        pass
    def on_epoch_end(self, epoch, logs=None):
        print(json.dumps({'epoch': epoch, 'loss': logs['loss']}))
    def on_batch_begin(self, batch, logs=None):
        pass
    def on_batch_end(self, batch, logs=None):
        pass
    def on_train_begin(self, logs=None):
        pass
    def on_train_end(self, logs=None):
        pass
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=128,
    verbose=0,
    callbacks=[JsonLogging()]
)
# => {"epoch": 0, "loss": 2.3880539054870606} etc.