
Training setup

We keep the same training setup as in the patch-based classification section.

Question

Create a file named part_3_train.py in which you will implement the following.

Metrics

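As in the previous sections, our metrics are precision, recall, and F1-score, computed for each class. Precision and recall are the built-in keras.metrics.Precision and keras.metrics.Recall. The F1-score is the FScore class that we have already defined in mymetrics.py.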

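    # Precision and recall for each class
    metrics = [
        cls(class_id=class_id)
        for class_id in range(class_nb)
        for cls in [keras.metrics.Precision, keras.metrics.Recall]
    ]

    # F1-Score for each class (FScore is imported from mymetrics)
    metrics += [
        FScore(class_id=class_id, name=f"fscore_cls{class_id}") for class_id in range(class_nb)
    ]

    model.compile(
        loss={tgt_key: keras.losses.CategoricalCrossentropy()},
        optimizer=keras.optimizers.Adam(params.learning_rate),
        metrics={tgt_key: metrics},
    )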
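The following minimal pure-Python sketch makes explicit what these per-class metrics measure. It rests on three assumptions. First, as with the default settings of keras.metrics.Precision and keras.metrics.Recall, a pixel counts as predicted for class k when its probability for k is strictly above 0.5. Second, each pixel of each patch is one sample. Third, FScore computes the harmonic mean of precision and recall; this sketch does not check that against mymetrics.py. The function name per_class_scores is hypothetical.

```python
def per_class_scores(y_true, y_pred, class_id, threshold=0.5):
    """Return (precision, recall, f1) for one class.

    y_true: one-hot vectors, one per pixel.
    y_pred: probability vectors (e.g. softmax outputs), one per pixel.
    A pixel is predicted as `class_id` when its probability for that class
    is strictly greater than `threshold` (assumed default of 0.5).
    """
    tp = fp = fn = 0
    for truth, probs in zip(y_true, y_pred):
        actual = truth[class_id] == 1
        predicted = probs[class_id] > threshold
        if predicted and actual:
            tp += 1  # true positive
        elif predicted:
            fp += 1  # false positive
        elif actual:
            fn += 1  # false negative
    # Avoid dividing by zero when a class is never predicted or never present
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1-score as the harmonic mean of precision and recall (assumed FScore definition)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Four pixels, two classes (made-up values for illustration)
y_true = [[1, 0], [0, 1], [0, 1], [1, 0]]
y_pred = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.6, 0.4]]
print(per_class_scores(y_true, y_pred, class_id=1))  # (1.0, 0.5, 0.6666666666666666)
```

A class can have perfect precision and poor recall at the same time. That is why we track all three scores for every class rather than a single global accuracy.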

Callbacks

We use the following callbacks during training.

ModelCheckpoint

We let Keras compute the metrics over the validation dataset after each epoch. The model is saved only when the validation loss reaches a new minimum, so the saved file always holds the model with the lowest validation loss.

    callbacks = []
    callbacks.append(
        keras.callbacks.ModelCheckpoint(
            params.model, mode="min", save_best_only=True, monitor="val_loss"
        )
    )
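The rule applied by ModelCheckpoint with save_best_only=True and mode="min" can be sketched in plain Python. This is a simplified model, not the Keras code, and the function name is hypothetical. The file is overwritten only when the monitored value strictly improves on the best value seen so far.

```python
def checkpoint_epochs(val_losses):
    """Return the 1-based epochs at which a checkpoint with
    save_best_only=True and mode="min" would write the model file.
    """
    best = float("inf")  # nothing saved yet, so the first value always improves
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strict improvement only: ties do not trigger a save
            best = loss
            saved.append(epoch)
    return saved


# Made-up validation losses for five epochs
print(checkpoint_epochs([0.9, 0.7, 0.8, 0.6, 0.6]))  # [1, 2, 4]
```

With these made-up losses, the file on disk ends up holding the epoch 4 model even though training ran for five epochs. This is why we reload params.model before the final evaluation.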

TensorBoard

We keep a trace of the metrics for TensorBoard using the Keras built-in keras.callbacks.TensorBoard callback:

    if params.log_dir:
        callbacks.append(keras.callbacks.TensorBoard(log_dir=params.log_dir))

This is optional: if params.log_dir is not provided (the --log_dir command line argument), the TensorBoard callback is not used.

BackupAndRestore

Since training our model can take a while on CPU, we use a callback that regularly saves the training state. Training can then be interrupted and resumed at any time (e.g. after a hardware failure, or when the process is killed).

Info

The backup-and-restore callback deletes its checkpoint once training has completed, i.e. after the number of epochs specified in fit() has been performed.

    if params.ckpt_dir:
        callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=params.ckpt_dir))

This is optional: if params.ckpt_dir is not provided (the --ckpt_dir command line argument), the backup-and-restore callback is not used.

Warning

When the --ckpt_dir argument is provided on the command line, training always resumes from the state saved in that directory. Do not forget to delete the directory when you want to restart training from scratch after a previously interrupted run.
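The resume logic can be illustrated with a toy epoch loop in plain Python. This is a simplified analogue, not the Keras implementation. The real BackupAndRestore callback also saves the model weights and optimizer state, whereas this sketch only records the index of the next epoch in a JSON file (state.json is a hypothetical name). SimulatedInterrupt stands in for Ctrl+C.

```python
import json
import os


class SimulatedInterrupt(Exception):
    """Stands in for a Ctrl+C or a killed process."""


def run_epochs(backup_dir, epochs, interrupt_at=None):
    """Run a toy epoch loop that resumes from, and backs up to, `backup_dir`.

    Returns the 0-based epoch indices run during this call. If `interrupt_at`
    is given, SimulatedInterrupt is raised before that epoch starts.
    """
    os.makedirs(backup_dir, exist_ok=True)
    state_file = os.path.join(backup_dir, "state.json")  # hypothetical file name
    start = 0
    if os.path.exists(state_file):  # a backup exists: resume after the last completed epoch
        with open(state_file) as f:
            start = json.load(f)["next_epoch"]
    ran = []
    for epoch in range(start, epochs):
        if epoch == interrupt_at:
            raise SimulatedInterrupt(f"interrupted before epoch {epoch}")
        ran.append(epoch)  # a real loop would train one epoch here
        with open(state_file, "w") as f:  # back up after each completed epoch
            json.dump({"next_epoch": epoch + 1}, f)
    if os.path.exists(state_file):  # training completed: drop the backup
        os.remove(state_file)
    return ran
```

Take a temporary directory d. The call run_epochs(d, 5, interrupt_at=3) raises after completing epochs 0 to 2. A second call, run_epochs(d, 5), runs only epochs 3 and 4 and then removes the backup, so a further call starts again from epoch 0. This is why the directory only needs to be deleted by hand after an interrupted run.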

Question

  • Implement the backup-and-restore callback,
  • Check that the callback works by interrupting the training process with Ctrl+C, and that training restarts at the right epoch with the same metrics.

Final evaluation

Finally, once the model is trained, the saved file holds the model from the epoch where the validation loss reached its lowest value. We compute the metrics over the test dataset, which gives a proper evaluation of the model's performance on an independent dataset.

    # Final evaluation on the test dataset
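    # Reload the best weights saved by ModelCheckpoint (lowest val_loss)
    model.load_weights(params.model)
    # ds_test is already batched by create_dataset(), so no batch_size is passed here
    values = model.evaluate(ds_test, return_dict=True)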
    print("\nMetrics over test dataset:")
    for name, value in values.items():
        print(f"{name}\t{value}")

Question

  • Append this code after model.fit(...), at the end of the train() function.

You can now run the training. Use --ckpt_dir to select the directory where the backup checkpoints are written, --log_dir to select the directory for the TensorBoard logs, and --epochs to set the number of epochs.

    python part_3_train.py \
      --model /data/models/model3.keras \
      --log_dir /data/logs/model3 \
      --epochs 50 \
      --ckpt_dir /data/ckpts/model3

Question

  • Compare the metrics over the validation dataset, the test dataset and the training dataset.
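For this comparison, it can help to print the three sets of metrics side by side. The sketch below is one possible helper; format_comparison is a hypothetical name. It takes the dicts returned by model.evaluate(..., return_dict=True) for each split and builds a tab-separated table with one row per metric.

```python
def format_comparison(results):
    """Return a tab-separated table: one row per metric, one column per split.

    results: mapping from split name to the dict returned by
    model.evaluate(..., return_dict=True) on that split.
    """
    splits = list(results)
    names = sorted({name for values in results.values() for name in values})
    lines = ["metric\t" + "\t".join(splits)]
    for name in names:
        # "-" marks a metric missing from one split
        cells = [
            f"{results[split][name]:.4f}" if name in results[split] else "-"
            for split in splits
        ]
        lines.append(name + "\t" + "\t".join(cells))
    return "\n".join(lines)


# Possible use at the end of train(), once the best weights are loaded:
# results = {
#     "train": model.evaluate(ds_train, return_dict=True),
#     "valid": model.evaluate(ds_valid, return_dict=True),
#     "test": model.evaluate(ds_test, return_dict=True),
# }
# print(format_comparison(results))
```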
Solution
part_3_train.py
"""Semantic segmentation of Spot-7 images"""

from mymetrics import FScore
import otbtf
import keras
import argparse


class_nb = 4  # number of classes
inp_key_p = "input_p"  # model input p
inp_key_xs = "input_xs"  # model input xs
tgt_key = "estimated"  # model target


def create_otbtf_dataset(p, xs, labels):
    return otbtf.DatasetFromPatchesImages(filenames_dict={"p": p, "xs": xs, "labels": labels})


def dataset_preprocessing_fn(sample):
    return {
        inp_key_p: sample["p"],
        inp_key_xs: sample["xs"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(p, xs, labels, batch_size=8):
    otbtf_dataset = create_otbtf_dataset(p, xs, labels)
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )


def conv(inp, depth, name, strides=2):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=3,
        strides=strides,
        activation="relu",
        padding="same",
        name=name,
    )
    return conv_op(inp)


def tconv(inp, depth, name, activation="relu"):
    tconv_op = keras.layers.Conv2DTranspose(
        filters=depth,
        kernel_size=3,
        strides=2,
        activation=activation,
        padding="same",
        name=name,
    )
    return tconv_op(inp)


class FCNNModel(otbtf.ModelBase):
    def normalize_inputs(self, inputs):
        return {
            inp_key_p: keras.ops.cast(inputs[inp_key_p], "float32") * 0.01,
            inp_key_xs: keras.ops.cast(inputs[inp_key_xs], "float32") * 0.01,
        }

    def get_outputs(self, normalized_inputs):
        norm_inp_xs = normalized_inputs[inp_key_xs]
        cv_xs = conv(norm_inp_xs, 32, "conv_xs", 1)

        norm_inp_p = normalized_inputs[inp_key_p]
        cv1 = conv(norm_inp_p, 16, "conv1")
        cv2 = conv(cv1, 32, "conv2") + cv_xs
        cv3 = conv(cv2, 64, "conv3")
        cv4 = conv(cv3, 64, "conv4")
        cv1t = tconv(cv4, 64, "conv1t") + cv3
        cv2t = tconv(cv1t, 32, "conv2t") + cv2
        cv3t = tconv(cv2t, 16, "conv3t") + cv1
        cv4t = tconv(cv3t, class_nb, "softmax_layer", activation="softmax")

        argmax_op = otbtf.layers.Argmax()

        return {tgt_key: cv4t, "estimated_labels": argmax_op(cv4t)}


def train(params, ds_train, ds_valid, ds_test):
    model = FCNNModel(dataset_element_spec=ds_train.element_spec)

    # Precision and recall for each class
    metrics = [
        cls(class_id=class_id)
        for class_id in range(class_nb)
        for cls in [keras.metrics.Precision, keras.metrics.Recall]
    ]

    # F1-Score for each class
    metrics += [
        FScore(class_id=class_id, name=f"fscore_cls{class_id}") for class_id in range(class_nb)
    ]

    model.compile(
        loss={tgt_key: keras.losses.CategoricalCrossentropy()},
        optimizer=keras.optimizers.Adam(params.learning_rate),
        metrics={tgt_key: metrics},
    )
    model.summary()
    callbacks = []
    callbacks.append(
        keras.callbacks.ModelCheckpoint(
            params.model, mode="min", save_best_only=True, monitor="val_loss"
        )
    )
    if params.log_dir:
        callbacks.append(keras.callbacks.TensorBoard(log_dir=params.log_dir))
    if params.ckpt_dir:
        callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=params.ckpt_dir))

    # Train the model
    model.fit(
        ds_train,
        epochs=params.epochs,
        validation_data=ds_valid,
        callbacks=callbacks,
    )

    # Final evaluation on the test dataset
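    # Reload the best weights saved by ModelCheckpoint (lowest val_loss)
    model.load_weights(params.model)
    # ds_test is already batched by create_dataset(), so no batch_size is passed here
    values = model.evaluate(ds_test, return_dict=True)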
    print("\nMetrics over test dataset:")
    for name, value in values.items():
        print(f"{name}\t{value}")


parser = argparse.ArgumentParser(description="Train a FCNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--log_dir", help="logs directory")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--ckpt_dir", help="Directory for checkpoints")
params = parser.parse_args()

# Pass --batch_size through, otherwise create_dataset() falls back to its default of 8
ds_train = create_dataset(
    ["/data/train_p_patches.tif"],
    ["/data/train_xs_patches.tif"],
    ["/data/train_labels_patches.tif"],
    batch_size=params.batch_size,
)
ds_train = ds_train.shuffle(buffer_size=100)

ds_valid = create_dataset(
    ["/data/valid_p_patches.tif"],
    ["/data/valid_xs_patches.tif"],
    ["/data/valid_labels_patches.tif"],
    batch_size=params.batch_size,
)

ds_test = create_dataset(
    ["/data/test_p_patches.tif"],
    ["/data/test_xs_patches.tif"],
    ["/data/test_labels_patches.tif"],
    batch_size=params.batch_size,
)

train(params, ds_train, ds_valid, ds_test)

Dig deeper 🚀

Image summary

Use TensorFlow summaries to monitor additional data during the training process.

Question

  • Create a summary writer in the same directory as the TensorBoard logs,
  • Add a new summary showing the softmax output image for class id 1.
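A possible starting point is sketched below. It assumes TensorFlow is installed and that softmax is a NumPy array of shape [B, H, W, C], taken from the estimated output of the model for a few patches; how you obtain that array is left to you. tf.summary.create_file_writer() and tf.summary.image() are standard TensorFlow summary functions, while the helper names below are hypothetical.

```python
import numpy as np


def class_probability_image(softmax, class_id=1):
    """Keep one class channel of a softmax batch as a [B, H, W, 1] float image."""
    softmax = np.asarray(softmax, dtype=np.float32)
    return softmax[..., class_id : class_id + 1]  # slicing keeps the channel axis


def write_softmax_summary(writer, softmax, step, class_id=1, max_outputs=3):
    """Log the probability map of `class_id` as an image summary.

    writer: a tf.summary file writer, e.g. tf.summary.create_file_writer(log_dir).
    softmax: array of shape [B, H, W, C] with values in [0, 1].
    """
    import tensorflow as tf  # deferred so the helper above does not need TensorFlow

    with writer.as_default():
        tf.summary.image(
            f"softmax_cls{class_id}",
            class_probability_image(softmax, class_id),
            step=step,
            max_outputs=max_outputs,
        )
    writer.flush()
```

For example, create the writer once with writer = tf.summary.create_file_writer(params.log_dir). Then call write_softmax_summary(writer, softmax, step=epoch) whenever you want a new image. Probabilities in [0, 1] are displayed as a grayscale image in TensorBoard.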

Custom callback

One can implement a custom callback using the keras callback API.

Question

  • Implement a custom callback that wraps the softmax output image summary. The callback must generate the image summary after each epoch (override on_epoch_end()).
  • Test your callback and check that the image summaries are displayed in TensorBoard.
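One possible shape for this callback is sketched below. It subclasses keras.callbacks.Callback, writes to its own summary writer, and runs the model on a fixed set of input patches in on_epoch_end(). The class is built inside a factory function only so that Keras and TensorFlow are imported when the callback is created; defining it at module level in part_3_train.py works just as well. The sketch makes two assumptions. First, the model outputs are a dict whose "estimated" entry (tgt_key) holds the softmax. Second, sample_inputs is a dict with input_p and input_xs tensors. The names make_softmax_image_callback and SoftmaxImageCallback are hypothetical.

```python
def make_softmax_image_callback(log_dir, sample_inputs, class_id=1, output_key="estimated"):
    """Build a Keras callback that logs one class probability map after each epoch.

    sample_inputs: dict of input tensors ("input_p", "input_xs") for a few patches.
    output_key: model output holding the softmax (tgt_key in part_3_train.py).
    """
    import keras
    import numpy as np
    import tensorflow as tf

    class SoftmaxImageCallback(keras.callbacks.Callback):
        def __init__(self):
            super().__init__()
            self.writer = tf.summary.create_file_writer(log_dir)

        def on_epoch_end(self, epoch, logs=None):
            # self.model is the model being trained, set by fit()
            outputs = self.model(sample_inputs, training=False)
            softmax = np.asarray(outputs[output_key])  # assumed shape [B, H, W, C]
            with self.writer.as_default():
                tf.summary.image(
                    f"softmax_cls{class_id}",
                    softmax[..., class_id : class_id + 1],  # keep the channel axis
                    step=epoch,
                    max_outputs=3,
                )
            self.writer.flush()

    return SoftmaxImageCallback()
```

You could then register it with callbacks.append(make_softmax_image_callback(params.log_dir, sample_inputs)). Here sample_inputs could come from one validation batch, e.g. sample_inputs, _ = next(iter(ds_valid)), assuming the otbtf dataset yields (inputs, targets) pairs.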