
Build and train a model

In this section, we will build and train a small model that takes as input patches of the 10 m spacing Sentinel-2 image and estimates the class of the center pixel. Our CNN takes as input the 16x16x4 patches we previously generated, and predicts an output class among 6 labels ranging from 0 to 5 (see the table in the Terrain truth section).

The principle of model training is to use the training dataset to optimize the network by minimizing the loss function. At the end of each epoch, the validation dataset is used to compute the average loss over its samples, so that the best model can be selected and exported during training.
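To make the selection rule concrete, here is a pure-Python sketch of what "selecting the best model" means: a model is exported whenever the validation loss reaches a new minimum. The helper checkpoint_trace() and the loss values below are hypothetical, for illustration only.

```python
def checkpoint_trace(val_losses):
    """Return the epochs at which the model would be exported, and the best loss.

    A model is exported each time the validation loss is strictly lower than
    every loss seen in previous epochs.
    """
    best = float("inf")
    exported = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # new minimum: this epoch's model replaces the saved one
            best = loss
            exported.append(epoch)
    return exported, best


# Hypothetical validation losses, one per epoch (not real results)
losses = [0.92, 0.71, 0.65, 0.68, 0.63, 0.66]
print(checkpoint_trace(losses))  # ([0, 1, 2, 4], 0.63)
```

The model kept at the end is the one from epoch 4, even though training continued for one more epoch.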

Let's start

Create a new Python script, part_1_train.py. We start by importing the necessary Python modules:

import argparse
import otbtf
import keras

And we define a few constants that will be used multiple times:

class_nb = 6  # number of classes
inp_key = "input"  # model input
tgt_key = "estimated"  # model target

TensorFlow datasets

We first build the training and validation datasets, from which our model will consume samples during training. To do that, we use the otbtf.DatasetFromPatchesImages class to convert our previously generated GeoTIFF patches into a tf.data.Dataset object, i.e. something that Keras/TensorFlow is able to use.

We also need to convert the input and the target into the right format. While the input can be provided as is, the target needs to be one-hot encoded, as required by the categorical cross-entropy loss. We build a preprocessing function named dataset_preprocessing_fn() that takes a single sample as input and returns a dictionary whose keys are the names of the inputs and targets, and whose values are the tensors to provide to the model and to the loss operator.

To avoid doing things twice (once for the training dataset and once for the validation dataset), we build a helper function named create_dataset() that builds the TensorFlow dataset from lists of GeoTIFF patch files.

def dataset_preprocessing_fn(sample):
    """The preprocessing function transforms labels into one-hot encoding"""
    return {
        inp_key: sample["img"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(img, labels, batch_size=8):
    """This function returns a TensorFlow dataset"""
    otbtf_dataset = otbtf.DatasetFromPatchesImages(
        filenames_dict={"img": img, "labels": labels}
    )
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )
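The one-hot transformation applied by dataset_preprocessing_fn() can be illustrated without TensorFlow. This is a pure-Python sketch of the idea, not the implementation of otbtf.ops.one_hot(); it assumes each label patch is 1x1x1 (the class of the center pixel), and one_hot() / one_hot_patch() are hypothetical helpers.

```python
def one_hot(label, nb_classes):
    """Return the one-hot vector of an integer class label."""
    if not 0 <= label < nb_classes:
        raise ValueError(f"label {label} is outside [0, {nb_classes - 1}]")
    return [1.0 if c == label else 0.0 for c in range(nb_classes)]


def one_hot_patch(label_patch, nb_classes):
    """Turn an HxWx1 label patch (nested lists) into an HxWx(nb_classes) patch."""
    return [[one_hot(pixel[0], nb_classes) for pixel in row] for row in label_patch]


class_nb = 6
center_label = [[[3]]]  # hypothetical 1x1x1 label patch: the center pixel is class 3
print(one_hot_patch(center_label, class_nb))  # [[[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]]]
```

The last dimension goes from 1 (a class id) to 6 (one value per class), which is the shape the categorical cross-entropy expects for the target.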

We can then call this helper to build the datasets for training and validation:

# Training dataset
ds_train = create_dataset(
    ["/data/a_img_10m.tif"], ["/data/a_labels.tif"], batch_size=params.batch_size
)
ds_train = ds_train.shuffle(buffer_size=100)

# Validation dataset
ds_valid = create_dataset(
    ["/data/b_img_10m.tif"], ["/data/b_labels.tif"], batch_size=params.batch_size
)

Note

The training samples are shuffled at each epoch so that each batch gives a better estimate of the overall loss gradient. If the samples are always presented in the same order, gradient descent may get stuck in a local minimum while a better solution might exist in the neighborhood. Shuffling the batches of samples helps avoid that.
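A pure-Python sketch of what happens to ds_train between epochs: the batches keep their content, and only their order changes. It assumes create_dataset() returns an already batched dataset, so that shuffle() acts on batches. It also performs a full permutation, whereas a finite shuffle buffer (buffer_size=100) only approximates one when the dataset holds more elements than the buffer. make_batches() and shuffled_epochs() are hypothetical helpers.

```python
import random


def make_batches(samples, batch_size):
    """Group samples into consecutive batches (the last one may be smaller)."""
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]


def shuffled_epochs(batches, nb_epochs, seed=0):
    """Yield a fresh random order of the same batches for each epoch."""
    rng = random.Random(seed)  # seeded so that the sketch is reproducible
    for _ in range(nb_epochs):
        order = list(batches)  # copy: the batches themselves are not modified
        rng.shuffle(order)
        yield order


batches = make_batches(list(range(12)), batch_size=4)
for epoch, order in enumerate(shuffled_epochs(batches, nb_epochs=3)):
    print(epoch, order)
```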

Custom operators

We will build our model with convolution operators and max pooling.

Convolution

The convolution uses a Rectified Linear Unit (ReLU) as its default activation function, which can be changed with the activation argument. It performs the convolution with unit strides. The kernel size and the depth (number of filters) can be configured with the arguments. The convolution does not perform any padding and generates only the valid part of the output.
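To see what "only the valid part of the output" means, here is a single-channel pure-Python sketch of a valid convolution with unit strides, followed by a ReLU. Like Keras Conv2D, it slides the kernel without flipping it (cross-correlation); the real layer also sums over input channels and adds a bias, which this sketch leaves out. conv2d_valid() and relu() are hypothetical helpers. With stride 1, the output size is input_size - kernel_size + 1, e.g. 16 - 5 + 1 = 12 for conv1.

```python
def conv2d_valid(image, kernel):
    """Single-channel 2D 'valid' convolution with unit strides and no bias."""
    kh, kw = len(kernel), len(kernel[0])
    out_h, out_w = len(image) - kh + 1, len(image[0]) - kw + 1  # valid output size
    return [
        [
            sum(image[i + u][j + v] * kernel[u][v] for u in range(kh) for v in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def relu(grid):
    """Element-wise ReLU: max(0, x)."""
    return [[max(0, x) for x in row] for row in grid]


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
ones = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
print(conv2d_valid(image, ones))  # [[54, 63], [90, 99]]
print(relu([[-2, 3], [0, -1]]))   # [[0, 3], [0, 0]]
```

A 3x3 kernel fits in only 2x2 positions of a 4x4 image, so the output is 2x2: no value is computed where the kernel would overlap the border.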

def conv(inp, depth, kernel_size, name, activation="relu"):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=kernel_size,
        strides=1,
        activation=activation,
        padding="valid",
        name=name,
    )
    return conv_op(inp)

Max pooling

We use a 2x2 max pooling operator. Since no strides are specified, Keras uses the pool size, i.e. strides of (2, 2).

pool = keras.layers.MaxPool2D(pool_size=(2, 2))
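A pure-Python sketch of 2x2 max pooling with stride 2 on a single channel: each non-overlapping 2x2 block is replaced by its maximum, which halves each spatial dimension. With Keras's default "valid" padding, a trailing odd row or column is dropped, and this sketch does the same. max_pool_2x2() is a hypothetical helper.

```python
def max_pool_2x2(grid):
    """2x2 max pooling with stride 2 on a single-channel grid (nested lists)."""
    h, w = len(grid) // 2, len(grid[0]) // 2  # odd trailing row/column is dropped
    return [
        [
            max(grid[2 * i][2 * j], grid[2 * i][2 * j + 1],
                grid[2 * i + 1][2 * j], grid[2 * i + 1][2 * j + 1])
            for j in range(w)
        ]
        for i in range(h)
    ]


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(max_pool_2x2(image))  # [[6, 8], [14, 16]]
```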

Architecture

To build our model, we could start from scratch by subclassing keras.Model, but we will see how OTBTF helps a lot with the otbtf.ModelBase class.

class SimpleCNNModel(otbtf.ModelBase):
    """This is a subclass of `otbtf.ModelBase` to implement a CNN"""

Overview

We build a small convolutional neural network consisting of 4 trainable convolutional layers and 2 max pooling layers (not trainable). Our model takes 16x16x4 patches as input (patches of 16x16 pixels, each pixel having 4 components) and produces a 1x1x6 output for the pseudo-probability distribution of the classes (estimated), and a 1x1x1 output for the estimated class label (estimated_labels).

flowchart TD

i((input)) --> normalization -- 16x16x4 --> c1[conv 5x5 + ReLU]
c1 -- 12x12x16 --> p1[Max Pooling 2x2] -- 6x6x16 --> c2[conv 3x3 + ReLU]
c2 -- 4x4x32 --> p2[Max Pooling 2x2] -- 2x2x32 --> c3[conv 2x2 + ReLU]
c3 -- 1x1x64 --> c4[conv 1x1 + Softmax]
c4 -- 1x1x6 --> loss[Softmax cross entropy] --> optimizer
c4 --> argmax -- 1x1x1 --> p((labels))
tt((Terrain truth)) --> loss
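The shapes written on the diagram can be checked with a short pure-Python trace: a valid convolution with unit stride shrinks each spatial dimension by kernel_size - 1, and a 2x2 pooling with stride 2 halves it. trace_shapes() and the pool1/pool2 labels are illustrative only; they are not part of the OTBTF or Keras APIs.

```python
def trace_shapes(size=16, depth=4):
    """Return (layer name, height, width, depth) after each step of the network."""
    layers = [
        # (name, kind, kernel or pool size, number of filters)
        ("conv1", "conv", 5, 16),
        ("pool1", "pool", 2, None),
        ("conv2", "conv", 3, 32),
        ("pool2", "pool", 2, None),
        ("feats", "conv", 2, 64),
        ("softmax_layer", "conv", 1, 6),
    ]
    shapes = [("input", size, size, depth)]
    for name, kind, k, filters in layers:
        if kind == "conv":
            size = size - k + 1  # valid convolution, stride 1
            depth = filters      # output depth = number of filters
        else:
            size = size // k     # pooling with stride equal to pool size
        shapes.append((name, size, size, depth))
    return shapes


for shape in trace_shapes():
    print(shape)
```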

Input normalization

The normalization is implemented in SimpleCNNModel.normalize_inputs().

The model is intended to work on real-world images, whose pixel values are often 16-bit integers. The model has to normalize these values so that they fit approximately in the [0, 1] range before reaching the convolutions. To do that, we use a simple scaling of the image pixel values by a factor of 1e-4.
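A minimal pure-Python sketch of this scaling, assuming (as is common for Sentinel-2 products) that reflectances are stored as integers scaled by 10000. Values above 10000, e.g. on very bright surfaces, end up slightly above 1, which is why the range is only approximately [0, 1]. normalize() is a hypothetical stand-in for what normalize_inputs() does to each pixel value.

```python
SCALE = 1e-4  # same factor as in normalize_inputs()


def normalize(values):
    """Cast integer pixel values to float and scale them by 1e-4."""
    return [float(v) * SCALE for v in values]


# 0 -> 0.0, 2500 -> 0.25, 10000 -> 1.0, and 12000 -> 1.2 (slightly above 1)
print(normalize([0, 2500, 10000, 12000]))
```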

    def normalize_inputs(self, inputs):
        """This function normalizes the input, scaling values by 1e-4"""
        return {inp_key: keras.ops.cast(inputs[inp_key], "float32") * 1e-4}

Network implementation

The actual network is implemented in SimpleCNNModel.get_outputs().

The normalized input is processed by a succession of 2D convolutions with activation functions (conv1, conv2, feats) and two max pooling steps (both using the same pool layer). The features output by feats are processed by a 1x1 convolution layer of 6 neurons (one for each class) with a softmax activation function. The softmax function normalizes the outputs so that they sum to 1, so they can be used to represent a categorical distribution, i.e. a probability distribution over n different possible outcomes. This layer is named softmax_layer in our model.

The estimated class is the index of the neuron (in the last layer) that outputs the maximum value. It is obtained by processing the outputs of softmax_layer with the otbtf.layers.Argmax operator.
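A pure-Python sketch of the last two steps for a single pixel: softmax turns the 6 raw neuron values into a distribution that sums to 1, and argmax returns the index of the largest value, i.e. the class id. The input values are hypothetical, and softmax() / argmax() are illustrative helpers, not the Keras or OTBTF implementations.

```python
import math


def softmax(logits):
    """Exponentiate and normalize so that the outputs sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one in case of a tie)."""
    return max(range(len(values)), key=values.__getitem__)


scores = [0.5, 2.0, -1.0, 0.1, 1.2, 0.0]  # hypothetical raw values of the 6 neurons
probs = softmax(scores)
print(round(sum(probs), 6))  # 1.0
print(argmax(probs))         # 1
```

Since softmax preserves the order of its inputs, taking the argmax before or after it gives the same class.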

Finally, get_outputs() looks like this:

    def get_outputs(self, normalized_inputs):
        """This function implements the model"""
        inp = normalized_inputs[inp_key]
        net = conv(inp, 16, 5, "conv1")  # 12x12x16
        net = pool(net)  # 6x6x16
        net = conv(net, 32, 3, "conv2")  # 4x4x32
        net = pool(net)  # 2x2x32
        net = conv(net, 64, 2, "feats")  # 1x1x64

        # Classifier
        estim = conv(net, class_nb, 1, "softmax_layer", activation="softmax")
        argmax_op = otbtf.layers.Argmax()

        return {
            tgt_key: estim,
            "estimated_labels": argmax_op(estim),  # additional output: class id
            "features": net,  # additional output: features
        }

Optimization

We can now train our model!

Scope

First, we instantiate our model.

model = SimpleCNNModel(dataset_element_spec=ds_train.element_spec)

Compilation

Then, we specify the loss and the optimizer:

model.compile(
    loss={tgt_key: keras.losses.CategoricalCrossentropy()},
    optimizer=keras.optimizers.Adam(params.learning_rate),
)

The loss argument is a dictionary where the key is the name of the model output, and the value is the function used to compute the loss.

Warning

The names of the model outputs and of the training/validation dataset targets must match so that Keras can compute the loss.
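A minimal sketch of a sanity check you could run on the key names before calling fit(). check_loss_keys() is a hypothetical helper; the output names are the keys returned by get_outputs(), the target names are the targets_keys given to get_tf_dataset(), and None stands in for the loss objects.

```python
def check_loss_keys(loss, output_names, target_names):
    """Return the loss keys that are missing from the model outputs or dataset targets."""
    keys = set(loss)
    return {
        "missing_in_outputs": sorted(keys - set(output_names)),
        "missing_in_targets": sorted(keys - set(target_names)),
    }


outputs = ["estimated", "estimated_labels", "features"]  # keys returned by get_outputs()
targets = ["estimated"]                                   # targets_keys of the datasets
print(check_loss_keys({"estimated": None}, outputs, targets))  # nothing missing
print(check_loss_keys({"labels": None}, outputs, targets))     # "labels" missing in both
```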

We detail in the following subsections the role of loss and optimizer.

Loss

The goal of training is to minimize the cross-entropy between the data distribution (real labels) and the model distribution (estimated labels). Our loss function is the cross-entropy between the softmax output of the 6 neurons (the tgt_key output of our model) and the one-hot encoded labels of the training dataset. The categorical cross-entropy measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry belongs to exactly one class).

The categorical cross-entropy is implemented in Keras as keras.losses.CategoricalCrossentropy().
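The formula behind this loss is short: for one sample, it is minus the sum over classes of y_true[c] * log(y_pred[c]), which, with a one-hot target, reduces to minus the log of the probability predicted for the true class. The pure-Python sketch below clips predictions away from 0 to avoid log(0) and averages over a batch; Keras handles such details itself, so this only illustrates the formula. The functions and values are hypothetical.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """-sum_c y_true[c] * log(y_pred[c]), with predictions clipped to [eps, 1 - eps]."""
    return -sum(t * math.log(min(max(p, eps), 1.0 - eps)) for t, p in zip(y_true, y_pred))


def mean_loss(batch_true, batch_pred):
    """Average the per-sample loss over a batch."""
    losses = [categorical_crossentropy(t, p) for t, p in zip(batch_true, batch_pred)]
    return sum(losses) / len(losses)


y_true = [0, 0, 0, 1, 0, 0]                     # one-hot target: class 3
y_pred = [0.05, 0.05, 0.1, 0.7, 0.05, 0.05]     # hypothetical softmax output
print(categorical_crossentropy(y_true, y_pred))  # equals -log(0.7)
```

The loss only depends on the probability given to the true class: it tends to 0 when that probability tends to 1, and grows quickly as it approaches 0.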

Optimizer

The optimizer performs the minimization of the loss. It is used at training time only, not during inference. We use Adam, an optimizer based on adaptive moment estimation, as implemented in Keras by keras.optimizers.Adam().
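To show what adaptive moment estimation means, here is the textbook Adam update for a single scalar parameter, in pure Python. It keeps running averages of the gradient (first moment) and of the squared gradient (second moment), corrects their bias towards 0, and steps along their ratio. The default learning rate matches the script's default of 0.0002, and beta1, beta2 and eps are the usual defaults; library implementations may differ in details such as where eps is added. adam_step() is a hypothetical helper.

```python
import math


def adam_step(theta, grad, m, v, t, lr=0.0002, beta1=0.9, beta2=0.999, eps=1e-7):
    """One textbook Adam update of a scalar parameter (t is the 1-based step)."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients (1st moment)
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients (2nd moment)
    m_hat = m / (1 - beta1 ** t)               # bias correction: m and v start at 0
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Minimize f(x) = x**2, whose gradient is 2x, starting from x = 1
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 6):
    x, m, v = adam_step(x, 2 * x, m, v, t, lr=0.1)
    print(t, round(x, 4))
```

On the first step, the parameter moves by about lr whatever the scale of the gradient, because the averaged gradient is divided by the square root of the averaged squared gradient. This is why the learning rate of Adam behaves roughly like a step size.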

Summary

After compiling the model, we can summarize it like this:

model.summary()
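As a sanity check for the summary, the trainable parameter counts of the convolution layers can be worked out by hand: each Conv2D layer has kernel_size x kernel_size x input_depth x filters weights, plus one bias per filter (Keras uses a bias by default). This sketch assumes that the pooling, normalization and Argmax steps add no trainable parameters, so the total should match the trainable parameter count reported by model.summary(). conv_params() is a hypothetical helper.

```python
def conv_params(kernel_size, in_depth, filters):
    """Weights of a square Conv2D kernel plus one bias per filter."""
    return kernel_size * kernel_size * in_depth * filters + filters


# (layer name, kernel size, input depth, number of filters), as in get_outputs()
LAYERS = [
    ("conv1", 5, 4, 16),
    ("conv2", 3, 16, 32),
    ("feats", 2, 32, 64),
    ("softmax_layer", 1, 64, 6),
]

total = 0
for name, k, d_in, f in LAYERS:
    n = conv_params(k, d_in, f)
    total += n
    print(f"{name:14s} {n:6d}")
print(f"{'total':14s} {total:6d}")  # 14902
```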

Fitting

Now we call fit(), as for any Keras model (you can read more on the API in the Keras documentation). This runs the actual training of the model.

model.fit(
    ds_train,
    epochs=params.epochs,
    validation_data=ds_valid,
    callbacks=[keras.callbacks.ModelCheckpoint(params.model, save_best_only=True)],
)

Here we use the keras.callbacks.ModelCheckpoint(params.model, save_best_only=True) callback, which saves the best model for us. After each epoch, the loss is computed over the ds_valid validation dataset (the callback monitors val_loss by default); if it has decreased, a copy of the model is exported to the params.model file. Without save_best_only=True, the callback would overwrite that file at the end of every epoch, whether or not the validation loss improved. The number of epochs is set with the epochs argument.

Question

Create the Python script part_1_train.py. We want to be able to run it from the command line. Use the argparse module to build a simple parser, so that we can provide the following arguments:

  • model: the file to save the trained model
  • batch_size: the batch size
  • learning_rate: the learning rate
  • epochs: the number of epochs

parser = argparse.ArgumentParser(description="Train a CNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
params = parser.parse_args()
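To check the parser without a real command line, parse_args() also accepts an explicit list of strings. A small sketch, wrapping the same arguments in a hypothetical build_parser() function so that it can be reused:

```python
import argparse


def build_parser():
    """Same arguments as above, wrapped in a function."""
    parser = argparse.ArgumentParser(description="Train a CNN model")
    parser.add_argument("--model", required=True, help="model path (.keras file)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=0.0002)
    parser.add_argument("--epochs", type=int, default=100)
    return parser


# Simulate: python part_1_train.py --model /data/models/model1.keras --epochs 20
params = build_parser().parse_args(["--model", "/data/models/model1.keras", "--epochs", "20"])
print(params.model, params.batch_size, params.learning_rate, params.epochs)
```

Arguments that are not given keep their defaults (batch_size=4, learning_rate=0.0002), while --epochs is converted from the string "20" to the integer 20 by type=int.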

Put the pieces together, and run your code using the following command line:

python part_1_train.py --model /data/models/model1.keras

Solution
import argparse
import otbtf
import keras


class_nb = 6  # number of classes
inp_key = "input"  # model input
tgt_key = "estimated"  # model target


def dataset_preprocessing_fn(sample):
    """The preprocessing function transforms labels into one-hot encoding"""
    return {
        inp_key: sample["img"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(img, labels, batch_size=8):
    """This function returns a TensorFlow dataset"""
    otbtf_dataset = otbtf.DatasetFromPatchesImages(
        filenames_dict={"img": img, "labels": labels}
    )
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )


def conv(inp, depth, kernel_size, name, activation="relu"):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=kernel_size,
        strides=1,
        activation=activation,
        padding="valid",
        name=name,
    )
    return conv_op(inp)


pool = keras.layers.MaxPool2D(pool_size=(2, 2))


class SimpleCNNModel(otbtf.ModelBase):
    """This is a subclass of `otbtf.ModelBase` to implement a CNN"""

    def normalize_inputs(self, inputs):
        """This function normalizes the input, scaling values by 1e-4"""
        return {inp_key: keras.ops.cast(inputs[inp_key], "float32") * 1e-4}

    def get_outputs(self, normalized_inputs):
        """This function implements the model"""
        inp = normalized_inputs[inp_key]
        net = conv(inp, 16, 5, "conv1")  # 12x12x16
        net = pool(net)  # 6x6x16
        net = conv(net, 32, 3, "conv2")  # 4x4x32
        net = pool(net)  # 2x2x32
        net = conv(net, 64, 2, "feats")  # 1x1x64

        # Classifier
        estim = conv(net, class_nb, 1, "softmax_layer", activation="softmax")
        argmax_op = otbtf.layers.Argmax()

        return {
            tgt_key: estim,
            "estimated_labels": argmax_op(estim),  # additional output: class id
            "features": net,  # additional output: features
        }


parser = argparse.ArgumentParser(description="Train a CNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
params = parser.parse_args()

# Training dataset
ds_train = create_dataset(
    ["/data/a_img_10m.tif"], ["/data/a_labels.tif"], batch_size=params.batch_size
)
ds_train = ds_train.shuffle(buffer_size=100)

# Validation dataset
ds_valid = create_dataset(
    ["/data/b_img_10m.tif"], ["/data/b_labels.tif"], batch_size=params.batch_size
)

model = SimpleCNNModel(dataset_element_spec=ds_train.element_spec)
model.compile(
    loss={tgt_key: keras.losses.CategoricalCrossentropy()},
    optimizer=keras.optimizers.Adam(params.learning_rate),
)
model.summary()
model.fit(
    ds_train,
    epochs=params.epochs,
    validation_data=ds_valid,
    callbacks=[keras.callbacks.ModelCheckpoint(params.model, save_best_only=True)],
)