Training networks

In this section we will learn how to feed mini-batches into a network for training or inference. Let us assume we have a Keras model for a classification network

from keras.models import Sequential
from keras.layers import Convolution2D, Dense, Activation

model = Sequential()
model.add(Convolution2D(32, (3, 3), input_shape=INPUT_SHAPE))
...
model.add(Dense(NUM_CLASSES))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

and let us further assume we have a pipeline that generates mini-batches as described in the previous section

batches = train_samples >> read_image >> ... >> build_batch

we could then train the model (for a single epoch) using the train_on_batch method provided by Keras

for batch in batches:
    model.train_on_batch(*batch)

or a bit more explicitly

for inputs, outputs in batches:
    model.train_on_batch(inputs, outputs)

Note that batches is a generator and not a list of batches – there is no consumer such as Consume() or Collect() at the end of the pipeline. Also, we have to ensure that the shape of the generated batches matches the INPUT_SHAPE of the model – a common source of errors. PrintType() can be used to print the shape of the generated batches.
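
For instance, inserting PrintType() right after the batch builder prints the data type and shape of every mini-batch flowing through the pipeline (a minimal sketch; read_image and build_batch are the nuts from the previous section):

batches = train_samples >> read_image >> ... >> build_batch >> PrintType()
for batch in batches:
    model.train_on_batch(*batch)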

Keras supports another method for training, fit_generator, which expects an infinite stream of mini-batches. Such a stream can easily be created by adding a Cycle() nut after the loading of the training samples. Since the stream never ends, fit_generator also needs to be told how many batches constitute an epoch:

batches = train_samples >> Cycle() >> read_image >> ... >> build_batch
model.fit_generator(batches, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)
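
Here EPOCHS is the number of training epochs and STEPS_PER_EPOCH the number of mini-batches per epoch. Assuming train_samples is a list, the latter could, for instance, be computed as

STEPS_PER_EPOCH = len(train_samples) // BATCH_SIZE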

However, the easiest way to train a Keras network is to take advantage of the KerasNetwork wrapper provided by nuts-ml. It takes a Keras model and wraps it into a nut that can directly be plugged into a pipeline:

network = KerasNetwork(model)
train_samples >> read_image >> ... >> build_batch >> network.train() >> Consume()

Note that we need a Consume() at the end of the pipeline to pull the data through. In the examples above, train_on_batch and fit_generator acted as the consumers. network.train() trains the network and emits the loss and any specified metric (e.g. accuracy in this example) per mini-batch. We can collect this output and report the average loss and accuracy per epoch:

network = KerasNetwork(model)
for epoch in range(EPOCHS):
    t_loss, t_acc = train_samples >> ... >> build_batch >> network.train() >> Unzip()
    print("train loss  :", t_loss >> Mean())
    print("train acc   :", t_acc >> Mean())

Apart from the training loss (and accuracy) we often want to know the network's performance on a validation set. The data preprocessing pipelines for training and validation are very similar, but typically we do not augment the data when validating. Below is a code sketch for training and validation:

network = KerasNetwork(model)
for epoch in range(EPOCHS):
    t_loss, t_acc = (train_samples >> read_image >> transform >> augment >>
                     Shuffle(100) >> build_batch >> network.train() >> Unzip())
    print("train loss  :", t_loss >> Mean())
    print("train acc   :", t_acc >> Mean())

    v_loss, v_acc = (val_samples >> read_image >> transform >>
                     build_batch >> network.validate() >> Unzip())
    print("val loss  :", v_loss >> Mean())
    print("val acc   :", v_acc >> Mean())

Note that we skip the augmentation and shuffling that are part of the training pipeline when validating.

Training and validation performance are averaged over batches. The true performance, however, needs to be computed on a per-sample basis. nuts-ml provides evaluate() for this purpose. For instance, the code sketch below calls network.evaluate() to compute the categorical_accuracy over all test samples:

from keras.metrics import categorical_accuracy

e_acc = (test_samples >> read_image >> transform >> build_batch >>
         network.evaluate([categorical_accuracy]))
print("evaluation acc  :", e_acc)

This code would typically run after the epoch loop, when the network training is complete. Note that evaluate is a sink (no Collect() needed) and returns a single number per metric (no averaging required).

Finally, once we have trained the network and are happy with the classification accuracy, we would like to use it for inference/prediction. Prediction differs from training, validation and evaluation in that we don't know the target/output values – those are what we want to infer. Consequently, the mini-batches need to be constructed without outputs and can then be fed into the predict() method, which returns the softmax vectors:

build_pred_batch = BuildBatch(BATCH_SIZE).input(...)

predictions = (samples >> read_image >> transform >> build_pred_batch >>
               network.predict() >> Map(ArgMax()) >> Collect())

We use Map(ArgMax()) to retrieve the index of the class with the highest softmax probability and collect those indices as the network's predictions. Note that we could easily convert the class indices back to string labels using ConvertLabel, as shown below.
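
For instance, assuming LABELS is the list of class label names ordered by class index, a sketch of this conversion could look as follows (see the ConvertLabel documentation for the exact parameters):

predictions = (samples >> read_image >> transform >> build_pred_batch >>
               network.predict() >> Map(ArgMax()) >>
               ConvertLabel(None, LABELS) >> Collect())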