Building the model

We will now build a small U-Net-based model with 4 downscaling and 4 upscaling levels.

Architecture

The following figure presents the architecture of our model.

flowchart TD
tt((terrain truth)) --> loss

x1((pan)) --> n1[normalization] -- 64x64x1 --> c11[Conv 3x3, stride 2, ReLU]
c11 -- 32x32x16 --> c21[Conv 3x3, stride 2, ReLU]

x2((xs)) --> n2[normalization] -- 16x16x4 --> c12[Conv 3x3, stride 1, ReLU]

c21 -- 16x16x32 --> plus1((+))
c12 -- 16x16x32 --> plus1((+))

plus1((+)) --> c3[Conv 3x3, stride 2, ReLU]

c3 -- 8x8x64 --> c4[Conv 3x3, stride 2, ReLU]

c4 -- 4x4x64 --> c1t[Transposed Conv 3x3, stride 2, ReLU]
c1t -- 8x8x64 --> plus2((+))
c3 --> plus2((+))
plus2((+)) --> c2t[Transposed Conv 3x3, stride 2, ReLU]
plus1((+)) --> plus3((+))
plus3((+)) --> c3t[Transposed Conv 3x3, stride 2, ReLU]
c2t -- 16x16x32 --> plus3((+))
c11 --> plus4((+))
c3t --> plus4((+))
plus4((+)) -- 32x32x16 --> c4t[Transposed Conv 3x3, stride 2, Softmax]

c4t -- 64x64xN --> argmax -- 64x64x1 --> out((labels))
c4t --> loss[Cross entropy]
loss --> Optimizer

Here, N is the number of classes (4 in our case).
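To check that the skip connections in the figure add tensors of identical shapes, the shape arithmetic can be traced by hand. The sketch below does this in plain Python, without any deep learning library. It assumes "same" padding everywhere, so that a convolution divides the spatial size by its stride and a transposed convolution multiplies it by its stride. The helper names (`conv_same`, `tconv_same`, `add`) are hypothetical and only serve the illustration.

```python
import math


def conv_same(shape, depth, stride):
    """Shape after a 'same'-padded convolution: spatial size is ceil(n / stride)."""
    h, w, _ = shape
    return (math.ceil(h / stride), math.ceil(w / stride), depth)


def tconv_same(shape, depth, stride=2):
    """Shape after a 'same'-padded transposed convolution: spatial size is n * stride."""
    h, w, _ = shape
    return (h * stride, w * stride, depth)


def add(a, b):
    """Element-wise addition requires both operands to have the same shape."""
    assert a == b, f"shape mismatch: {a} vs {b}"
    return a


n_classes = 4
pan = (64, 64, 1)
xs = (16, 16, 4)

c11 = conv_same(pan, 16, 2)                 # (32, 32, 16)
c21 = conv_same(c11, 32, 2)                 # (16, 16, 32)
c12 = conv_same(xs, 32, 1)                  # (16, 16, 32)
plus1 = add(c21, c12)                       # pan and xs branches merge here
c3 = conv_same(plus1, 64, 2)                # (8, 8, 64)
c4 = conv_same(c3, 64, 2)                   # (4, 4, 64)
plus2 = add(tconv_same(c4, 64), c3)         # (8, 8, 64)
plus3 = add(tconv_same(plus2, 32), plus1)   # (16, 16, 32)
plus4 = add(tconv_same(plus3, 16), c11)     # (32, 32, 16)
c4t = tconv_same(plus4, n_classes)          # (64, 64, 4)
print(c4t)
```

Every `add` call passes its shape check, which is what allows the encoder and decoder branches to be summed directly.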

Implementation

We proceed in the same way as in the patch-based classification section.

Constants

First we define some constants:

class_nb = 4  # number of classes
inp_key_p = "input_p"  # model input p
inp_key_xs = "input_xs"  # model input xs
tgt_key = "estimated"  # model target

Dataset helpers

Then we define a few helpers to build the datasets:

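A helper to create an otbtf dataset from lists of patches
def create_otbtf_dataset(p, xs, labels):
    return otbtf.DatasetFromPatchesImages(filenames_dict={"p": p, "xs": xs, "labels": labels})

Dataset preprocessing function
def dataset_preprocessing_fn(sample):
    # Map the patch sources to the model input keys and one-hot encode the labels
    return {
        inp_key_p: sample["p"],
        inp_key_xs: sample["xs"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }

TensorFlow dataset creation from lists of patches
# Note: the default batch size below is an assumption
def create_dataset(p, xs, labels, batch_size=8):
    otbtf_dataset = create_otbtf_dataset(p, xs, labels)
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )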
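The preprocessing function one-hot encodes the label patches so that they can be compared with the softmax output of the model. As a rough illustration of what this encoding does, here is a minimal plain-Python sketch. It is not the implementation of `otbtf.ops.one_hot`; the function below is a hypothetical stand-in that works on nested lists instead of tensors.

```python
def one_hot(labels, nb_classes):
    """Turn a 2D grid of integer class ids into a grid of one-hot vectors."""
    return [
        [[1.0 if k == label else 0.0 for k in range(nb_classes)] for label in row]
        for row in labels
    ]


# A tiny 2x2 label patch with class ids in [0, 3]
labels = [[0, 3], [2, 1]]
encoded = one_hot(labels, nb_classes=4)
print(encoded[0][1])  # [0.0, 0.0, 0.0, 1.0]
```

Each pixel goes from a single class id to a vector of length `nb_classes`, so a label patch of size HxWx1 becomes HxWxN, matching the shape of the softmax layer output.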

The datasets can be instantiated using the helpers:

Datasets instantiation
ds_train = create_dataset(
    ["/data/train_p_patches.tif"],
    ["/data/train_xs_patches.tif"],
    ["/data/train_labels_patches.tif"],
)
ds_train = ds_train.shuffle(buffer_size=100)

ds_valid = create_dataset(
    ["/data/valid_p_patches.tif"],
    ["/data/valid_xs_patches.tif"],
    ["/data/valid_labels_patches.tif"],
)

ds_test = create_dataset(
    ["/data/test_p_patches.tif"],
    ["/data/test_xs_patches.tif"],
    ["/data/test_labels_patches.tif"],
)

Operators

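We define a convolution operator with a default stride of 2. Note that, unlike in patch-based classification, we use padding here ("same" padding), so that the output spatial size depends only on the stride (the input size divided by the stride) and not on the kernel size. This makes it straightforward to build the network, since the tensors of the downscaling and upscaling branches then have matching sizes: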

Convolution operator
def conv(inp, depth, name, strides=2):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=3,
        strides=strides,
        activation="relu",
        padding="same",
        name=name,
    )
    return conv_op(inp)
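To see why padding matters here, it helps to compare the output sizes obtained with and without it. The sketch below uses the usual output-size formulas for a convolution along one spatial dimension: floor((n - k) / s) + 1 without padding ("valid") and ceil(n / s) with "same" padding. The function names are hypothetical and only serve the illustration.

```python
import math


def out_size_valid(n, kernel, stride):
    """Output size along one dimension without padding ('valid')."""
    return (n - kernel) // stride + 1


def out_size_same(n, stride):
    """Output size along one dimension with 'same' padding."""
    return math.ceil(n / stride)


# Compare both modes for a 3x3 kernel and a stride of 2
for n in (64, 32, 16, 8):
    print(n, out_size_valid(n, kernel=3, stride=2), out_size_same(n, stride=2))
```

With "same" padding the sizes are halved exactly at each level (64, 32, 16, 8, 4), whereas without padding a 64-pixel input would already give 31 pixels after the first convolution, which a stride-2 transposed convolution could not bring back to 64.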

We do the same for the transposed convolution, which upsamples the tensors in the spatial dimensions:

Transposed convolution operator
def tconv(inp, depth, name, activation="relu"):
    tconv_op = keras.layers.Conv2DTranspose(
        filters=depth,
        kernel_size=3,
        strides=2,
        activation=activation,
        padding="same",
        name=name,
    )
    return tconv_op(inp)

Model

The model is built following the previously detailed architecture:

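# Note: the class name below is an assumption
class UNetModel(otbtf.ModelBase):
    def normalize_inputs(self, inputs):
        # Cast to float and rescale the input values
        return {
            inp_key_p: keras.ops.cast(inputs[inp_key_p], "float32") * 0.01,
            inp_key_xs: keras.ops.cast(inputs[inp_key_xs], "float32") * 0.01,
        }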

    def get_outputs(self, normalized_inputs):
        norm_inp_xs = normalized_inputs[inp_key_xs]
        cv_xs = conv(norm_inp_xs, 32, "conv_xs", 1)

        norm_inp_p = normalized_inputs[inp_key_p]
        cv1 = conv(norm_inp_p, 16, "conv1")
        cv2 = conv(cv1, 32, "conv2") + cv_xs
        cv3 = conv(cv2, 64, "conv3")
        cv4 = conv(cv3, 64, "conv4")
        cv1t = tconv(cv4, 64, "conv1t") + cv3
        cv2t = tconv(cv1t, 32, "conv2t") + cv2
        cv3t = tconv(cv2t, 16, "conv3t") + cv1
        cv4t = tconv(cv3t, class_nb, "softmax_layer", activation="softmax")

        argmax_op = otbtf.layers.Argmax()

        return {tgt_key: cv4t, "estimated_labels": argmax_op(cv4t)}
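The model returns two outputs: the softmax probabilities, used by the loss during training, and the labels obtained by taking the argmax over the class dimension. The following plain-Python sketch illustrates that last step on a tiny 2x2 example. It is not the implementation of `otbtf.layers.Argmax`; the function name and the probability values are made up for the illustration.

```python
def argmax_labels(probabilities):
    """For each pixel, return the index of the most probable class."""
    return [
        [max(range(len(pixel)), key=pixel.__getitem__) for pixel in row]
        for row in probabilities
    ]


# 2x2 pixels, 4 class probabilities per pixel (each vector sums to 1)
probs = [
    [[0.7, 0.1, 0.1, 0.1], [0.1, 0.2, 0.6, 0.1]],
    [[0.25, 0.4, 0.25, 0.1], [0.1, 0.1, 0.1, 0.7]],
]
labels = argmax_labels(probs)
print(labels)  # [[0, 2], [1, 3]]
```

The class dimension of size N collapses to a single class index per pixel, which is why the label output has one channel while the softmax output has N.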

Question

  • Create part_3_train.py and implement the model.
  • Try to implement the training setup on your own. You can start from code written in a previous section, e.g. part_2_train_fcn.py.

Note

The solution is given in the next section!