Build and train a model
In this section, we will build and train a small model that takes patches of the 10 m spacing Sentinel-2 image as input and estimates the class of the center pixel. Our CNN takes the 16x16x4 patches we have previously generated and predicts an output class among 6 labels, ranging from 0 to 5 (see the table in the Terrain truth section).
The training dataset is used to optimize the network, i.e. to minimize the loss function. The validation dataset is used to compute the average loss at the end of each epoch, in order to select, during the training, the best model to export.
Let's start
Create a new Python script, part_1_training.py. We start by importing the
necessary Python modules:
And we define a few constants that will be used multiple times:
class_nb = 6 # number of classes
inp_key = "input" # model input
tgt_key = "estimated" # model target
TensorFlow datasets
We first build the training and validation datasets, from which our model will
consume samples during the training. To do that, we use the
otbtf.DatasetFromPatchesImages class to convert our previously generated
GeoTiff patches into a tf.data.Dataset object, i.e. something that
Keras/TensorFlow is able to use.
We also need to convert the input and the target to the right format.
While the input can be provided as it is, the target needs to be
transformed into a one-hot encoding, as required by the cross-entropy
loss operator.
We build a preprocessing function named dataset_preprocessing_fn()
that takes a single sample and returns a dictionary whose
keys are the names of the inputs and targets, and whose
values are the tensors to provide to the model and to the loss operator.
To avoid writing the same code twice, once for the training dataset and
once for the validation dataset, we build a helper function named
create_dataset() that builds the TensorFlow dataset from the
lists of GeoTiff patches files.
def dataset_preprocessing_fn(sample):
    """The preprocessing function transforms labels in one hot encoding"""
    return {
        inp_key: sample["img"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(img, labels, batch_size=8):
    """This function returns a TensorFlow dataset"""
    otbtf_dataset = otbtf.DatasetFromPatchesImages(
        filenames_dict={"img": img, "labels": labels}
    )
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )
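To make the one-hot transformation concrete, here is a minimal pure-Python sketch. It does not call otbtf.ops.one_hot; the helper name one_hot_encode is hypothetical and only illustrates what the encoded target looks like for our 6 classes.

```python
def one_hot_encode(label, nb_classes):
    """Return a list of nb_classes floats with a single 1.0 at index `label`."""
    if not 0 <= label < nb_classes:
        raise ValueError(f"label {label} is outside [0, {nb_classes - 1}]")
    return [1.0 if i == label else 0.0 for i in range(nb_classes)]


class_nb = 6  # number of classes, as in the tutorial
encoded = one_hot_encode(2, class_nb)
print(encoded)  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```

Each encoded target sums to 1, so it can be compared directly with the pseudo-probabilities produced by the model.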
We can then call this helper to build the datasets for training and validation:
# Training dataset
ds_train = create_dataset(["/data/a_img_10m.tif"], ["/data/a_labels.tif"])
ds_train = ds_train.shuffle(buffer_size=100)
# Validation dataset
ds_valid = create_dataset(["/data/b_img_10m.tif"], ["/data/b_labels.tif"])
Note
The training samples are shuffled at each epoch so that the gradient is not always estimated from the same sequence of samples. If the order of the samples never changes, the gradient descent is more likely to get stuck in a local minimum while a better solution might exist in the neighborhood. Shuffling the batches of samples helps to avoid that.
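The sketch below illustrates the idea of buffer-based shuffling in plain Python. It is a conceptual illustration in the spirit of shuffle(buffer_size=100), not TensorFlow's actual implementation; the function name buffered_shuffle and the seed values are assumptions made for the example.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a pseudo-random order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]  # emit a randomly chosen buffered element...
        buffer[idx] = item  # ...and replace it with the incoming one
    rng.shuffle(buffer)  # drain what is left once the input is exhausted
    yield from buffer


batches = list(range(10))
epoch_1 = list(buffered_shuffle(batches, buffer_size=4, seed=1))
epoch_2 = list(buffered_shuffle(batches, buffer_size=4, seed=2))
print(epoch_1)
print(epoch_2)
```

Every batch is still seen exactly once per epoch; only the order changes. With a buffer smaller than the dataset, the shuffling is only local, so a larger buffer gives a more uniform mix at the cost of memory.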
Custom operators
We will build our model with convolution and max pooling operators.
Convolution
The convolution uses a Rectified Linear Unit (ReLU) as its default activation function and unit strides. The kernel size and the depth (number of filters) are set with the function arguments. The convolution does not perform any padding and generates only the valid part of the output.
def conv(inp, depth, kernel_size, name, activation="relu"):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=kernel_size,
        strides=1,
        activation=activation,
        padding="valid",
        name=name,
    )
    return conv_op(inp)
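To see what "valid" padding means for the output size, the following pure-Python sketch applies a single-channel convolution with unit strides followed by a ReLU. It is an illustration only: the function name conv2d_valid_relu, the square single-channel image and the averaging kernel are assumptions of the example, whereas the real layer works on multi-channel tensors with learned kernels.

```python
def conv2d_valid_relu(image, kernel):
    """Single-channel 2D convolution (no padding, stride 1) followed by ReLU."""
    k = len(kernel)
    out_size = len(image) - k + 1  # "valid" output size for a square image
    output = []
    for row in range(out_size):
        out_row = []
        for col in range(out_size):
            acc = sum(
                image[row + i][col + j] * kernel[i][j]
                for i in range(k)
                for j in range(k)
            )
            out_row.append(max(acc, 0.0))  # ReLU: negative responses become 0
        output.append(out_row)
    return output


image = [[float(r * 16 + c) for c in range(16)] for r in range(16)]  # 16x16
kernel = [[1.0 / 25.0] * 5 for _ in range(5)]  # 5x5 averaging kernel
features = conv2d_valid_relu(image, kernel)
print(len(features), len(features[0]))  # 12 12
```

A 5x5 kernel sliding over a 16x16 image without padding can only be placed at 12 positions along each axis, hence the 12x12 output.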
Max pooling
We use a 2x2 max pooling operator. Its strides default to the pool size, i.e. (2, 2).
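As an illustration, here is a minimal pure-Python sketch of 2x2 max pooling with strides (2, 2) on a small single-channel feature map. The function name max_pool_2x2 and the input values are made up for the example; it assumes a map with even height and width.

```python
def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2 on a 2D list with even height and width."""
    return [
        [
            max(
                feature_map[r][c],
                feature_map[r][c + 1],
                feature_map[r + 1][c],
                feature_map[r + 1][c + 1],
            )
            for c in range(0, len(feature_map[0]), 2)
        ]
        for r in range(0, len(feature_map), 2)
    ]


feature_map = [
    [1, 3, 2, 0],
    [4, 2, 1, 1],
    [0, 0, 5, 6],
    [1, 2, 7, 8],
]
pooled = max_pool_2x2(feature_map)
print(pooled)  # [[4, 2], [2, 8]]
```

Each non-overlapping 2x2 window is reduced to its maximum, so the spatial size is halved along both axes.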
Architecture
We could build our model from scratch by subclassing
keras.Model, but we will see how OTBTF makes this much easier with the
otbtf.ModelBase class.
class SimpleCNNModel(otbtf.ModelBase):
    """This is a subclass of `otbtf.ModelBase` to implement a CNN"""
Overview
We build a small convolutional neural network consisting of 4 trainable convolution layers and 2 max pooling layers (not trainable). Our model takes 16x16x4 patches as input (patches of 16x16 pixels, each pixel having 4 components) and produces a 1x1x6 output for the pseudo-probability distribution of the classes (estimated), and a 1x1x1 output for the estimated class label (estimated_labels).
flowchart TD
i((input)) --> normalization -- 16x16x4 --> c1[conv 5x5 + ReLU]
c1 -- 12x12x16 --> p1[Max Pooling 2x2] -- 6x6x16 --> c2[conv 3x3 + ReLU]
c2 -- 4x4x32 --> p2[Max Pooling 2x2] -- 2x2x32 --> c3[conv 2x2 + ReLU]
c3 -- 1x1x64 --> c4[conv 1x1 + Softmax]
c4 -- 1x1x6 --> loss[Softmax cross entropy] --> optimizer
c4 --> argmax -- 1x1x1 --> p((labels))
tt((Terrain truth)) --> loss
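The shapes in the flowchart can be checked with a few lines of arithmetic. The sketch below is plain Python and does not build any Keras layer: it walks through the layer list, applies the "valid" convolution and pooling size rules, and counts the trainable parameters of each convolution (weights plus one bias per filter). The layers list is transcribed from the flowchart; the variable names are assumptions of the example.

```python
# (name, kind, kernel size, output depth); depth is None for pooling layers
layers = [
    ("conv1", "conv", 5, 16),
    ("pool1", "pool", 2, None),
    ("conv2", "conv", 3, 32),
    ("pool2", "pool", 2, None),
    ("feats", "conv", 2, 64),
    ("softmax_layer", "conv", 1, 6),
]

size, depth = 16, 4  # input patches are 16x16x4
shapes = {}
param_counts = {}
for name, kind, kernel, out_depth in layers:
    if kind == "conv":
        # kernel x kernel x in_depth x out_depth weights, plus one bias per filter
        param_counts[name] = kernel * kernel * depth * out_depth + out_depth
        size = size - kernel + 1  # "valid" padding, stride 1
        depth = out_depth
    else:
        size = size // kernel  # pooling window and stride are both `kernel`
    shapes[name] = (size, size, depth)
    print(f"{name}: {shapes[name]}")

total_params = sum(param_counts.values())
print("trainable parameters:", total_params)  # 14902
```

The 16x16 patch is reduced to a single 1x1 position at the feats layer, which is why the model estimates the class of one pixel (the center of the patch).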
Input normalization
The normalization is implemented in SimpleCNNModel.normalize_inputs().
The model is intended to work on real-world images, whose pixel values are often encoded as 16-bit signed integers. The model has to normalize these values so that they fit the [0, 1] range before reaching the convolutions. To do that, we simply scale the image pixel values by a constant factor.
    def normalize_inputs(self, inputs):
        """This function normalizes the input, scaling values by 1e-4"""
        return {inp_key: keras.ops.cast(inputs[inp_key], "float32") * 1e-4}
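The effect of this scaling can be illustrated without Keras. The sketch below assumes pixel values that roughly span 0 to 10000, which is an assumption of the example and not something enforced by the model; the helper name scale_pixels is hypothetical.

```python
def scale_pixels(values, factor=1e-4):
    """Cast integer pixel values to float and scale them by a constant factor."""
    return [float(v) * factor for v in values]


pixels = [0, 1250, 5000, 10000]  # made-up 16-bit integer pixel values
scaled = scale_pixels(pixels)
print(scaled)
```

Note that a plain scaling does not clip anything: a pixel value above 10000 would be mapped above 1.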
Network implementation
The actual network is implemented in SimpleCNNModel.get_outputs().
The normalized input is processed by a succession of 2D convolutions with their activation function (conv1, conv2, feats) and pooling layers (pool1, pool2). The features produced by feats are then processed by a 1x1 convolution layer of 6 neurons (one for each predicted class) with a softmax activation function. The softmax function normalizes the outputs so that their sum is equal to 1; they can then be used to represent a categorical distribution, i.e. a probability distribution over n different possible outcomes. The softmax layer is named softmax_layer in our model.
The estimated class is the index of the neuron (from the last neuron
layer) that outputs the maximum value. It is obtained by processing
the outputs of softmax_layer with the otbtf.layers.Argmax
operator.
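The following pure-Python sketch shows what the softmax and argmax steps compute for a single pixel. It is only an illustration with made-up values for the 6 neurons; the model itself relies on the Keras softmax activation and on otbtf.layers.Argmax.

```python
import math


def softmax(logits):
    """Turn raw scores into positive values that sum to 1."""
    shift = max(logits)  # subtracting the max improves numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


logits = [0.5, 2.0, -1.0, 0.0, 1.0, 0.2]  # made-up outputs of the 6 neurons
probs = softmax(logits)
label = argmax(probs)
print([round(p, 3) for p in probs])
print(label)  # 1
```

The softmax keeps the ranking of the raw scores, so the estimated class is the same whether the argmax is taken before or after it.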
Finally, get_outputs() looks like:
    def get_outputs(self, normalized_inputs):
        """This function implements the model"""
        inp = normalized_inputs[inp_key]
        net = conv(inp, 16, 5, "conv1")  # 12x12x16
        net = pool(net)  # 6x6x16
        net = conv(net, 32, 3, "conv2")  # 4x4x32
        net = pool(net)  # 2x2x32
        net = conv(net, 64, 2, "feats")  # 1x1x64

        # Classifier
        estim = conv(net, class_nb, 1, "softmax_layer", activation="softmax")
        argmax_op = otbtf.layers.Argmax()

        return {
            tgt_key: estim,
            "estimated_labels": argmax_op(estim),  # additional output: class id
            "features": net,  # additional output: features
        }
Optimization
We can now train our model!
Scope
First, we instantiate our model.
Compilation
Then, we specify the loss and the optimizer:
model.compile(
    loss={tgt_key: keras.losses.CategoricalCrossentropy()},
    optimizer=keras.optimizers.Adam(params.learning_rate),
)
We provide the loss as a dictionary where each key is the name of a model
output, and the value is the function used to compute the loss for that output.
Warning
The names of the model outputs and of the training/validation dataset targets must match, so that keras is able to compute the loss.
The following subsections detail the role of the loss and of the optimizer.
Loss
The goal of the training is to minimize the cross-entropy between the
data distribution (real labels) and the model distribution (estimated
labels). Our loss function is the cross-entropy between the softmax output of
the 6 neurons, returned by our model as the tgt_key output, and the one-hot
encoded labels of the training dataset.
The categorical cross-entropy measures the probability error in
discrete classification tasks in which the classes are mutually exclusive
(each entry belongs to exactly one class).
The categorical cross-entropy is implemented in keras in
keras.losses.CategoricalCrossentropy().
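For a single sample, the categorical cross-entropy is the negative sum over the classes of the one-hot target multiplied by the logarithm of the predicted probability. The sketch below is a plain Python illustration of that formula, not the Keras implementation; the eps clipping and the probability values are assumptions of the example.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    # eps avoids log(0); it is a safeguard chosen for this sketch
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


y_true = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]  # the real class is 1
confident = [0.02, 0.90, 0.02, 0.02, 0.02, 0.02]
uncertain = [0.10, 0.50, 0.10, 0.10, 0.10, 0.10]

loss_confident = categorical_crossentropy(y_true, confident)
loss_uncertain = categorical_crossentropy(y_true, uncertain)
print(round(loss_confident, 4), round(loss_uncertain, 4))
```

With a one-hot target, only the probability assigned to the real class contributes, so the loss reduces to minus the logarithm of that probability: the more confident the correct prediction, the lower the loss.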
Optimizer
The optimizer performs the minimization of the loss. This operator is used at
training time only; it is not used during inference. We use an optimizer that
employs the Adaptive Moment Estimation (Adam)
method, through the implementation provided in keras by
keras.optimizers.Adam().
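To give an intuition of what Adam does, here is a minimal sketch of its update rule applied to a one-dimensional toy problem in plain Python. It follows the usual formulation of the method (running averages of the gradient and of its square, with bias correction) and is not the Keras implementation; the toy loss (x - 3)^2, the learning rate and the number of steps are assumptions of the example.

```python
import math


def adam_minimize(grad_fn, x0, learning_rate=0.1, steps=500,
                  beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """Minimize a 1D function given its gradient, using Adam updates."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta_1 * m + (1.0 - beta_1) * g  # running mean of the gradient
        v = beta_2 * v + (1.0 - beta_2) * g * g  # running mean of its square
        m_hat = m / (1.0 - beta_1 ** t)  # bias correction for the first steps
        v_hat = v / (1.0 - beta_2 ** t)
        x -= learning_rate * m_hat / (math.sqrt(v_hat) + epsilon)
    return x


# Toy loss (x - 3)^2, whose gradient is 2 * (x - 3) and whose minimum is x = 3
x_min = adam_minimize(lambda x: 2.0 * (x - 3.0), x0=0.0)
print(round(x_min, 3))
```

In the real training, the same kind of update is applied to every trainable weight of the network, with the gradient of the loss computed on each batch.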
Summary
After compiling it, we can summarize the model like this:
Fitting
Now we call fit(), as with any keras model (you can read more on the API
in the keras documentation).
This will run the actual training of the model.
model.fit(
    ds_train,
    epochs=params.epochs,
    validation_data=ds_valid,
    # save_best_only=True keeps only the model with the lowest validation loss
    callbacks=[keras.callbacks.ModelCheckpoint(params.model, save_best_only=True)],
)
Here we use the keras.callbacks.ModelCheckpoint(params.model) callback
to save the best model. The ds_valid validation dataset is
used after each epoch to check whether the loss computed over the validation dataset
has decreased; if it has, a copy of the model is exported to the
params.model file. Note that the callback behaves this way only when it is
created with save_best_only=True; with the keras default settings, the file is
overwritten at the end of every epoch. The number of epochs is set with the epochs
argument.
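The selection rule described above can be sketched in a few lines of plain Python. This only mimics the "keep the model when the validation loss improves" logic and is not the Keras callback; the function name and the loss values are made up for the illustration.

```python
def epochs_where_model_is_saved(val_losses):
    """Return the epochs at which the validation loss reaches a new minimum."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # the validation loss has decreased: export the model
            best = loss
            saved.append(epoch)
    return saved


val_losses = [0.92, 0.71, 0.78, 0.64, 0.66]  # made-up validation losses
saved_epochs = epochs_where_model_is_saved(val_losses)
print(saved_epochs)  # [0, 1, 3]
```

The exported file therefore always holds the model of the last epoch that improved the validation loss (epoch 3 in this example), even if the training goes on for more epochs.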
Question
Create the Python script part_1_train.py. We want to be able to run it from the
command line. Use the argparse module to build a simple parser, so that we
can provide the following arguments:
- model: the file to save the trained model
- batch_size: the batch size
- learning_rate: the learning rate
- epochs: the number of epochs
parser = argparse.ArgumentParser(description="Train a CNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
params = parser.parse_args()
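To check the parser without launching a training, one option is to pass the arguments as a list to parse_args() instead of reading them from the command line. In the sketch below, the file name my_model.keras and the values are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser(description="Train a CNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)

# Hypothetical arguments, passed as a list instead of being read from sys.argv
params = parser.parse_args(["--model", "my_model.keras", "--epochs", "5"])
print(params.model, params.batch_size, params.learning_rate, params.epochs)
# my_model.keras 4 0.0002 5
```

Arguments that are not provided keep their default values, and the parsed values are available as attributes of params.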
Put the pieces together, and run your code using the following command line:
Solution
import argparse
import otbtf
import keras

class_nb = 6  # number of classes
inp_key = "input"  # model input
tgt_key = "estimated"  # model target


def dataset_preprocessing_fn(sample):
    """The preprocessing function transforms labels in one hot encoding"""
    return {
        inp_key: sample["img"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(img, labels, batch_size=8):
    """This function returns a TensorFlow dataset"""
    otbtf_dataset = otbtf.DatasetFromPatchesImages(
        filenames_dict={"img": img, "labels": labels}
    )
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )


def conv(inp, depth, kernel_size, name, activation="relu"):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=kernel_size,
        strides=1,
        activation=activation,
        padding="valid",
        name=name,
    )
    return conv_op(inp)


pool = keras.layers.MaxPool2D(pool_size=(2, 2))


class SimpleCNNModel(otbtf.ModelBase):
    """This is a subclass of `otbtf.ModelBase` to implement a CNN"""

    def normalize_inputs(self, inputs):
        """This function normalizes the input, scaling values by 1e-4"""
        return {inp_key: keras.ops.cast(inputs[inp_key], "float32") * 1e-4}

    def get_outputs(self, normalized_inputs):
        """This function implements the model"""
        inp = normalized_inputs[inp_key]
        net = conv(inp, 16, 5, "conv1")  # 12x12x16
        net = pool(net)  # 6x6x16
        net = conv(net, 32, 3, "conv2")  # 4x4x32
        net = pool(net)  # 2x2x32
        net = conv(net, 64, 2, "feats")  # 1x1x64

        # Classifier
        estim = conv(net, class_nb, 1, "softmax_layer", activation="softmax")
        argmax_op = otbtf.layers.Argmax()

        return {
            tgt_key: estim,
            "estimated_labels": argmax_op(estim),  # additional output: class id
            "features": net,  # additional output: features
        }


parser = argparse.ArgumentParser(description="Train a CNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
params = parser.parse_args()

# Training dataset (the batch size comes from the command line)
ds_train = create_dataset(
    ["/data/a_img_10m.tif"], ["/data/a_labels.tif"], batch_size=params.batch_size
)
ds_train = ds_train.shuffle(buffer_size=100)

# Validation dataset
ds_valid = create_dataset(
    ["/data/b_img_10m.tif"], ["/data/b_labels.tif"], batch_size=params.batch_size
)

model = SimpleCNNModel(dataset_element_spec=ds_train.element_spec)
model.compile(
    loss={tgt_key: keras.losses.CategoricalCrossentropy()},
    optimizer=keras.optimizers.Adam(params.learning_rate),
)
model.summary()
model.fit(
    ds_train,
    epochs=params.epochs,
    validation_data=ds_valid,
    # save_best_only=True keeps only the model with the lowest validation loss
    callbacks=[keras.callbacks.ModelCheckpoint(params.model, save_best_only=True)],
)