
Advanced

Introduction

In this section, we reuse the code that trains the semantic segmentation model, and we address a few advanced practices commonly used by deep learning practitioners.

Pre-trained model weights

In the previous sections, we trained a model from scratch. In this setup, the model weights are set at the beginning by the layers' initializers, which define how the initial random weights of the layers are drawn (the Keras documentation on this matter is a good read). However, we can decide to initialize the model weights differently, for instance by loading the weights of a pre-trained model.

First, we add a new optional argument to the parser for the pre-trained model to use:

parser.add_argument("--pretrained_model", help="pre-trained model path (.keras file)")

When the pretrained_model parameter is specified, we load the weights before the model compilation:

    if params.pretrained_model:
        print(f"Loading pre-trained model: {params.pretrained_model}")
        model.load_weights(params.pretrained_model)
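Because --pretrained_model has no default value, argparse sets it to None when the flag is omitted, so the `if params.pretrained_model:` branch is skipped and the model keeps the weights drawn by its initializers. A minimal sketch with a stand-alone parser that only reproduces this argument (the helper init_mode() and the file name m.keras are hypothetical, for illustration only):

```python
import argparse

# Stand-alone parser reproducing only the --pretrained_model argument
parser = argparse.ArgumentParser(description="Train a FCNN model")
parser.add_argument("--pretrained_model", help="pre-trained model path (.keras file)")


def init_mode(argv):
    """Hypothetical helper: tell which initialization the script would use."""
    params = parser.parse_args(argv)
    if params.pretrained_model:
        # This is where model.load_weights(params.pretrained_model) is called
        return f"pre-trained: {params.pretrained_model}"
    # Flag omitted: params.pretrained_model is None, initializers are kept
    return "from scratch"


print(init_mode([]))                                 # from scratch
print(init_mode(["--pretrained_model", "m.keras"]))  # pre-trained: m.keras
```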

Question

  • Copy part_3_train.py to part_4_train.py and implement the loading of pre-trained model weights,
  • Run the new code to start the training from the previously trained model (/data/model_semseg), with a single epoch, just to check that the training starts from a different initial state.

You can monitor the model metrics to see how they evolve, and compare them to the same metrics from the model trained from scratch.

Layers freezing

A common practice in deep learning is to freeze specific weights of the model during training. Frozen weights are treated as constants during the optimization process. Layer freezing is typically used to speed up training (fewer weights to optimize means fewer computations), or to train large models on small datasets. As a general rule, the number of layers to freeze is closely linked to how similar the task the pre-trained model was built for is to the task the network now has to solve. For instance, if the tasks are similar, e.g. the semantic segmentation of buildings on one hand, and the semantic segmentation of buildings and roads on the other hand, it can be worth training only the classifier of the model (i.e. the last layers). When the tasks are very different, the model will likely need to be trained with a larger number of layers unfrozen.
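To make "constant during the optimization" concrete, here is a toy sketch in NumPy rather than Keras: a linear model fitted by gradient descent, where a boolean mask plays the role of `layer.trainable = False` by cancelling the update of the frozen weight. The data, the mask and the learning rate are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression problem: y = X @ w_true
X = rng.normal(size=(64, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true

# "Pre-trained" initial weights and a mask telling which ones are trainable
w = np.array([1.0, 0.0, 0.0])
trainable = np.array([False, True, True])  # the first weight is frozen
w_init = w.copy()

lr = 0.1
for _ in range(200):
    # Gradient of the mean squared error with respect to w
    grad = 2.0 * X.T @ (X @ w - y) / len(X)
    # Frozen weights get a zero update, so they keep their initial value
    w -= lr * grad * trainable

print(w)
```

Only the unfrozen weights move towards their optimal values; the frozen one is never modified, which is also why fewer trainable weights mean fewer computations in the backward pass of a real network.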

We start by adding a new parameter to the parser. This parameter will contain a list of regular expressions (regex). We make it optional.

parser.add_argument("--frozen_layers", nargs="+", help="regex list", default=[])

This list of regexes will be used to select the layers to freeze.
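With nargs="+", every value following --frozen_layers is collected into a Python list, and default=[] keeps the later loop working when the flag is omitted (no layer matches an empty list of patterns). A quick check with a stand-alone parser reproducing only this argument:

```python
import argparse

parser = argparse.ArgumentParser(description="Train a FCNN model")
parser.add_argument("--frozen_layers", nargs="+", help="regex list", default=[])

# Flag omitted: the default empty list is used
print(parser.parse_args([]).frozen_layers)  # []

# Several regular expressions, one per command-line argument
args = parser.parse_args(["--frozen_layers", "conv[1-3]", "conv_xs"])
print(args.frozen_layers)  # ['conv[1-3]', 'conv_xs']
```

On the command line, quote each pattern so that the shell does not try to expand characters such as `*` or `[...]` itself.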

Warning

It does not make much sense to freeze untrained weights: they are usually randomly initialized, so their output is meaningless in the initial state. That is why the freezing code must only run when a pre-trained model is used to initialize the model weights.

    # Pre-trained model init. and layer freeze
    if params.pretrained_model:
        print(f"Loading pre-trained model: {params.pretrained_model}")
        model.load_weights(params.pretrained_model)
        for layer in model.model.layers:
            if any([re.fullmatch(pattern, layer.name) for pattern in params.frozen_layers]):
                print(f"Freezing {layer.name}")
                layer.trainable = False

After the modifications, you can freeze the weights using one or multiple regular expressions. Here are a few examples:

  • --frozen_layers 'conv(1|5)' 'convtranspose.*' will freeze the layers named conv1 and conv5, and all layers whose name starts with convtranspose,
  • --frozen_layers 'conv[1-3]' will freeze the conv1, conv2 and conv3 layers.
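Since the code uses re.fullmatch, a pattern must match the whole layer name. The sketch below applies the same selection rule to the layer names created in FCNNModel.get_outputs() of the solution (written by hand here instead of being read from model.model.layers, so input and add layers are left out; select_layers() is a hypothetical helper). It also shows that re.match, which only anchors at the start of the name, would additionally catch the transposed convolutions:

```python
import re

# Names given to the layers in FCNNModel.get_outputs() (listed by hand)
layer_names = [
    "conv_xs", "conv1", "conv2", "conv3", "conv4",
    "conv1t", "conv2t", "conv3t", "softmax_layer",
]


def select_layers(names, patterns, matcher=re.fullmatch):
    """Return the names matched by at least one of the regex patterns."""
    return [n for n in names if any(matcher(p, n) for p in patterns)]


print(select_layers(layer_names, ["conv[1-3]"]))
# ['conv1', 'conv2', 'conv3']
print(select_layers(layer_names, ["conv[1-3]"], matcher=re.match))
# ['conv1', 'conv2', 'conv3', 'conv1t', 'conv2t', 'conv3t']
print(select_layers(layer_names, ["conv.*t"]))
# ['conv1t', 'conv2t', 'conv3t']
```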

Question

  • Implement the selection of frozen layers
Solution
from mymetrics import FScore
import otbtf
import keras
import argparse
import re


class_nb = 4  # number of classes
inp_key_p = "input_p"  # model input p
inp_key_xs = "input_xs"  # model input xs
tgt_key = "estimated"  # model target


def create_otbtf_dataset(p, xs, labels):
    return otbtf.DatasetFromPatchesImages(filenames_dict={"p": p, "xs": xs, "labels": labels})


def dataset_preprocessing_fn(sample):
    return {
        inp_key_p: sample["p"],
        inp_key_xs: sample["xs"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }


def create_dataset(p, xs, labels, batch_size=8):
    otbtf_dataset = create_otbtf_dataset(p, xs, labels)
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )


def conv(inp, depth, name, strides=2):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=3,
        strides=strides,
        activation="relu",
        padding="same",
        name=name,
    )
    return conv_op(inp)


def tconv(inp, depth, name, activation="relu"):
    tconv_op = keras.layers.Conv2DTranspose(
        filters=depth,
        kernel_size=3,
        strides=2,
        activation=activation,
        padding="same",
        name=name,
    )
    return tconv_op(inp)


class FCNNModel(otbtf.ModelBase):
    def normalize_inputs(self, inputs):
        return {
            inp_key_p: keras.ops.cast(inputs[inp_key_p], "float32") * 0.01,
            inp_key_xs: keras.ops.cast(inputs[inp_key_xs], "float32") * 0.01,
        }

    def get_outputs(self, normalized_inputs):
        norm_inp_xs = normalized_inputs[inp_key_xs]
        cv_xs = conv(norm_inp_xs, 32, "conv_xs", 1)

        norm_inp_p = normalized_inputs[inp_key_p]
        cv1 = conv(norm_inp_p, 16, "conv1")
        cv2 = conv(cv1, 32, "conv2") + cv_xs
        cv3 = conv(cv2, 64, "conv3")
        cv4 = conv(cv3, 64, "conv4")
        cv1t = tconv(cv4, 64, "conv1t") + cv3
        cv2t = tconv(cv1t, 32, "conv2t") + cv2
        cv3t = tconv(cv2t, 16, "conv3t") + cv1
        cv4t = tconv(cv3t, class_nb, "softmax_layer", activation="softmax")

        argmax_op = otbtf.layers.Argmax()

        return {tgt_key: cv4t, "estimated_labels": argmax_op(cv4t)}


def train(params, ds_train, ds_valid, ds_test):
    model = FCNNModel(dataset_element_spec=ds_train.element_spec)

    # Precision and recall for each class
    metrics = [
        cls(class_id=class_id)
        for class_id in range(class_nb)
        for cls in [keras.metrics.Precision, keras.metrics.Recall]
    ]

    # F1-Score for each class
    metrics += [
        FScore(class_id=class_id, name=f"fscore_cls{class_id}") for class_id in range(class_nb)
    ]

    # Pre-trained model init. and layer freeze
    if params.pretrained_model:
        print(f"Loading pre-trained model: {params.pretrained_model}")
        model.load_weights(params.pretrained_model)
        for layer in model.model.layers:
            if any([re.fullmatch(pattern, layer.name) for pattern in params.frozen_layers]):
                print(f"Freezing {layer.name}")
                layer.trainable = False

    model.compile(
        loss={tgt_key: keras.losses.CategoricalCrossentropy()},
        optimizer=keras.optimizers.Adam(params.learning_rate),
        metrics={tgt_key: metrics},
    )
    model.summary()
    callbacks = []
    callbacks.append(
        keras.callbacks.ModelCheckpoint(
            params.model, mode="min", save_best_only=True, monitor="val_loss"
        )
    )
    if params.log_dir:
        callbacks.append(keras.callbacks.TensorBoard(log_dir=params.log_dir))
    if params.ckpt_dir:
        callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=params.ckpt_dir))

    # Train the model
    model.fit(
        ds_train,
        epochs=params.epochs,
        validation_data=ds_valid,
        callbacks=callbacks,
    )

    # Evaluation on the test dataset
    model.load_weights(params.model)
    values = model.evaluate(ds_test, batch_size=params.batch_size, return_dict=True)
    print("\nMetrics over test dataset:")
    for name, value in values.items():
        print(f"{name}\t{value}")


parser = argparse.ArgumentParser(description="Train a FCNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--log_dir", help="logs directory")
parser.add_argument("--pretrained_model", help="pre-trained model path (.keras file)")
parser.add_argument("--frozen_layers", nargs="+", help="regex list", default=[])
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--ckpt_dir", help="Directory for checkpoints")
params = parser.parse_args()

ds_train = create_dataset(
    ["/data/train_p_patches.tif"],
    ["/data/train_xs_patches.tif"],
    ["/data/train_labels_patches.tif"],
    batch_size=params.batch_size,
)
ds_train = ds_train.shuffle(buffer_size=100)

ds_valid = create_dataset(
    ["/data/valid_p_patches.tif"],
    ["/data/valid_xs_patches.tif"],
    ["/data/valid_labels_patches.tif"],
    batch_size=params.batch_size,
)

ds_test = create_dataset(
    ["/data/test_p_patches.tif"],
    ["/data/test_xs_patches.tif"],
    ["/data/test_labels_patches.tif"],
    batch_size=params.batch_size,
)

train(params, ds_train, ds_valid, ds_test)

Data augmentation

Data augmentation is a method to "enlarge" a dataset with artificial samples derived from the original data. Artificial samples can be synthetic (e.g. generated by generative networks) or augmented (i.e. random transformations of the original samples). In deep learning, augmented samples are often preferred because they are closer to the real ones. Generally, the larger the dataset, the better the model, and gathering many samples is costly. Data augmentation makes it possible to quickly increase the number of samples in a dataset without much effort. It brings more variation into the dataset, which helps to train better-performing models.

Keras has various image augmentation layers, but it is quite simple to implement your own data augmentation.

In the following exercise, you will implement a random luminosity change of the panchromatic and the multispectral images.

Question

  • Implement augment(sample), where sample is a dictionary of p, xs and labels patches, as manipulated in dataset_preprocessing_fn(). The function should return the transformed sample (with random transformations of p and xs, while labels stays unmodified),
  • Modify your training pipeline so that only the training dataset is augmented,
  • Train your model with and without data augmentation, and observe how the summarized metrics evolve.
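A possible shape for augment(), sketched with NumPy on a single sample rather than with TensorFlow operations inside the tf.data pipeline, so it only illustrates the logic: p and xs are scaled by the same random factor (a simple luminosity change) and labels is passed through untouched. The key names follow create_otbtf_dataset(); the ±10% range, the shared factor and the patch sizes are arbitrary assumptions.

```python
import numpy as np


def augment(sample, rng, max_change=0.1):
    """Randomly scale the luminosity of the p and xs patches of one sample.

    The same factor is applied to both images so that they stay consistent;
    the labels are returned unchanged.
    """
    factor = rng.uniform(1.0 - max_change, 1.0 + max_change)
    return {
        "p": sample["p"] * factor,
        "xs": sample["xs"] * factor,
        "labels": sample["labels"],
    }


rng = np.random.default_rng(42)
sample = {
    "p": np.full((64, 64, 1), 100.0),  # hypothetical panchromatic patch
    "xs": np.full((16, 16, 4), 50.0),  # hypothetical multispectral patch
    "labels": np.zeros((64, 64, 1), dtype=np.uint8),
}
out = augment(sample, rng)
print(out["p"][0, 0, 0], out["xs"][0, 0, 0])
```

In the actual pipeline, the same logic has to be written with TensorFlow operations and applied to the training dataset only; the validation and test datasets stay untouched so that the metrics are measured on real samples.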

Partial weight loading

If you have modified your model, for instance by adding a new layer (with weights) or by changing the shape of the weights of a layer, you can choose to ignore errors and continue loading by setting skip_mismatch=True and by_name=True in model.load_weights(...). Weights are then matched to layers by name instead of by order, and any layer with mismatching weights is skipped.

...
model.load_weights(params.pretrained_model, skip_mismatch=True, by_name=True)
...
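The rule behind by_name=True and skip_mismatch=True can be pictured with plain dictionaries: weights are matched by layer name, and a layer that is absent from the saved file or whose weights changed shape keeps its initial values. This is a simplified illustration of the behavior, not how Keras implements it: load_by_name() is a hypothetical helper, each layer holds a single array (real layers also have biases), and the names and shapes are assumptions.

```python
import numpy as np


def load_by_name(model_weights, saved_weights):
    """Copy saved weights into model_weights by layer name, skipping mismatches.

    Both arguments map a layer name to a weight array. Returns the names of the
    layers that were loaded and of those that were skipped.
    """
    loaded, skipped = [], []
    for name, current in model_weights.items():
        saved = saved_weights.get(name)
        if saved is not None and saved.shape == current.shape:
            model_weights[name] = saved.copy()
            loaded.append(name)
        else:
            # New layer, or a layer whose weight shape changed: keep its init
            skipped.append(name)
    return loaded, skipped


# Pre-trained model with 4 classes; new model with 6 classes and an extra layer
saved = {
    "conv1": np.ones((3, 3, 1, 16)),
    "softmax_layer": np.ones((3, 3, 4, 16)),
}
model = {
    "conv1": np.zeros((3, 3, 1, 16)),
    "extra_layer": np.zeros((3, 3, 16, 16)),
    "softmax_layer": np.zeros((3, 3, 6, 16)),
}
loaded, skipped = load_by_name(model, saved)
print(loaded, skipped)  # ['conv1'] ['extra_layer', 'softmax_layer']
```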

Warning

An important point is that the model to load must be saved with the .h5 extension: loading weights by name is only supported for HDF5 files, and the TensorFlow SavedModel format does not allow by_name=True and skip_mismatch=True.

Question

  • Change the number of classes, and add extra layers before the softmax,
  • Use the trained model from the Semantic segmentation part, as a pretrained model to load weights from,
  • Train the new model (first all layers, then only the last layers) and compare the processing times.

TFRecords

Training on very large datasets efficiently is hard. On high-performance hardware, reading GeoTIFF images often becomes an I/O bottleneck, and faster file formats and other paradigms have to be employed.

TFRecords are files that store a number of samples in a format that is efficient for building large TensorFlow datasets. They are commonly used in distributed environments to train models on large amounts of data. OTBTF includes components to convert patches-images datasets into pre-shuffled TFRecords-based datasets, and to read TFRecords directly as a TensorFlow dataset.
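The idea of a pre-shuffled dataset split into several files can be sketched with plain index bookkeeping: the samples are assigned, in random order, to shards of fixed size, each shard being meant for its own TFRecord file. This is an assumption about the general approach, not OTBTF's converter; shuffled_shards() is a hypothetical helper and the serialization itself is left out.

```python
import random


def shuffled_shards(n_samples, samples_per_shard, seed=0):
    """Split sample indices, in random order, into shards of fixed size.

    Each shard would be written to its own TFRecord file; the last shard
    may be smaller than samples_per_shard.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return [
        indices[i:i + samples_per_shard]
        for i in range(0, n_samples, samples_per_shard)
    ]


shards = shuffled_shards(n_samples=10, samples_per_shard=4)
print([len(s) for s in shards])  # [4, 4, 2]
```

Because the samples are already mixed when the files are written, training can read the shards sequentially (or in a random order) with only a small shuffle buffer, instead of shuffling the whole dataset in memory.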

Question

  • Read the otbtf documentation and convert your patches images into TFRecords,
  • Adapt your training setup to use TFRecords instead of patches images based datasets