Training networks
In this section we will learn how to feed mini-batches into a network for training or inference. Let us assume we have a Keras model for a classification network:
from keras.models import Sequential
from keras.layers import Convolution2D, Dense, Activation

model = Sequential()
model.add(Convolution2D(32, (3, 3), input_shape=INPUT_SHAPE))
...
model.add(Dense(NUM_CLASSES))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
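INPUT_SHAPE and NUM_CLASSES (like BATCH_SIZE and EPOCHS used further below) are placeholders that depend on your data. As a minimal sketch, hypothetical values for 32x32 RGB images and ten classes could look like this:

# Hypothetical example values; adjust to your data set
INPUT_SHAPE = (32, 32, 3)  # image height, width, channels
NUM_CLASSES = 10           # number of output classes
BATCH_SIZE = 32            # mini-batch size used by BuildBatch
EPOCHS = 10                # number of training epochs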
Let us further assume we have a pipeline that generates mini-batches, as described in the previous section:
batches = train_samples >> read_image >> ... >> build_batch
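As a reminder, read_image and build_batch could be defined roughly as follows. This is only a sketch using the ReadImage and BuildBatch nuts of nuts-ml; the sample layout, path pattern and dtypes are assumptions:

from nutsml import ReadImage, BuildBatch

# Assuming samples are (filename, class_index) tuples
read_image = ReadImage(0, 'images/*.png')  # hypothetical path pattern
build_batch = (BuildBatch(BATCH_SIZE)
               .input(0, 'image', 'float32')
               .output(1, 'one_hot', 'uint8', NUM_CLASSES))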
We could then train the model (for a single epoch) using the train_on_batch method provided by Keras:
for batch in batches:
    model.train_on_batch(*batch)
or a bit more explicitly:
for inputs, outputs in batches:
    model.train_on_batch(inputs, outputs)
Note that batches is a generator and not a list of batches – there is no consumer such as Consume() or Collect() at the end of the pipeline.
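If a list of batches were actually needed, the pipeline could instead be terminated with Collect(), e.g. (a sketch, with the intermediate processing nuts omitted):

# Materialize all batches into a list; only advisable for small data sets
batch_list = train_samples >> read_image >> build_batch >> Collect()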
We also have to ensure that the shape of the batches matches the INPUT_SHAPE of the model – a common source of errors. Use PrintType() to print the shape of the generated batches.
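For instance, a quick debugging run could temporarily add PrintType() to the end of the pipeline; a sketch, assuming the pipeline from above:

from nutsflow import Consume
from nutsml import PrintType

# Print type and shape of every generated batch, then discard the batches
train_samples >> read_image >> build_batch >> PrintType() >> Consume()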
Keras supports another method for training, fit_generator, which expects an infinite stream of mini-batches. This can easily be achieved by adding a Cycle nut after the loading of the training samples:
batches = train_samples >> Cycle() >> read_image >> ... >> build_batch
# steps_per_epoch tells Keras how many mini-batches make up one epoch
model.fit_generator(batches, steps_per_epoch=len(train_samples) // BATCH_SIZE)
However, the easiest way to train a Keras network is to take advantage of the KerasNetwork wrapper provided by nuts-ml. It takes a Keras model and wraps it into a nut that can directly be plugged into a pipeline:
network = KerasNetwork(model)
train_samples >> read_image >> ... >> build_batch >> network.train() >> Consume()
Note that we need a Consume() at the end of the pipeline to pull the data through. In the examples above, train_on_batch and fit_generator acted as the consumers.
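For completeness, the nuts used in this section come from the nuts-flow and nuts-ml packages. The imports would look roughly like this (the exact module layout may differ between versions):

from nutsflow import Consume, Collect, Unzip, Mean, Map
from nutsml import KerasNetwork, BuildBatch, ReadImage, PrintType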
network.train() trains the network and emits the loss and any specified metric (e.g. accuracy in this example) per mini-batch. We can collect this output and report the average loss and accuracy per epoch, using Unzip() to split the stream of (loss, accuracy) tuples into separate streams and Mean() to average them:
network = KerasNetwork(model)
for epoch in range(EPOCHS):
    t_loss, t_acc = train_samples >> ... >> build_batch >> network.train() >> Unzip()
    print("train loss :", t_loss >> Mean())
    print("train acc :", t_acc >> Mean())
Apart from the training loss (and accuracy) we often want to know the network's performance on a validation set. The data preprocessing pipelines are very similar in both cases, but typically we do not augment when validating. The following is a code sketch for training and validation:
network = KerasNetwork(model)
for epoch in range(EPOCHS):
    t_loss, t_acc = (train_samples >> read_image >> transform >> augment >>
                     Shuffle(100) >> build_batch >> network.train() >> Unzip())
    print("train loss :", t_loss >> Mean())
    print("train acc :", t_acc >> Mean())

    v_loss, v_acc = (val_samples >> read_image >> transform >>
                     build_batch >> network.validate() >> Unzip())
    print("val loss :", v_loss >> Mean())
    print("val acc :", v_acc >> Mean())
Note that we skip the augmentation and shuffling that are part of the training pipeline when validating.
Training and validation performance are averaged over batches. The true performance, however, needs to be computed on a per-sample basis. nuts-ml provides evaluate() for this purpose. For instance, the code sketch below calls network.evaluate() to compute the categorical_accuracy over all test samples:
from keras.metrics import categorical_accuracy

e_acc = (test_samples >> read_image >> transform >> build_batch >>
         network.evaluate([categorical_accuracy]))
print("evaluation acc :", e_acc)
This code would typically run after the epoch loop, once training of the network is complete. Note that evaluate is a sink (no Collect needed) and returns a single number per metric (no averaging required).
Finally, once we have trained the network and are happy with the classification accuracy, we would like to use the network for inference/prediction. Prediction differs from training, validation and evaluation in that we do not know the target/output values – those are what we want to infer. Consequently, the mini-batches need to be constructed without outputs and can then be fed into the predict() function, which returns the softmax vectors:
build_pred_batch = BuildBatch(BATCH_SIZE).input(...)
predictions = (samples >> read_image >> transform >> build_pred_batch >>
               network.predict() >> Map(ArgMax()) >> Collect())
We use Map(ArgMax()) to retrieve the index of the class with the highest softmax probability and collect those indices as the network's predictions. Note that we could easily convert the class indices to labels using ConvertLabel.
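For example, given a hypothetical list of label strings, the predicted class indices could be mapped back to labels like this (a sketch; see the ConvertLabel documentation for the exact signature):

from nutsml import ConvertLabel

labels = ['airplane', 'automobile', 'bird']  # hypothetical labels, one per class index
predictions = (samples >> read_image >> transform >> build_pred_batch >>
               network.predict() >> Map(ArgMax()) >>
               ConvertLabel(None, labels) >> Collect())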