Training setup
We keep the same training setup as in the patch-based classification section.
Question
Create a file named part_3_train.py in which you will implement the
following.
Metrics¶
As in the previous sections, our metrics are Precision, Recall, and
F1-Score, computed for each class. Precision and Recall come from Keras;
the F1-Score is the FScore class that we have already defined in mymetrics.py.
# Precision and recall for each class
metrics = [
    cls(class_id=class_id)
    for class_id in range(class_nb)
    for cls in [keras.metrics.Precision, keras.metrics.Recall]
]
# F1-Score for each class
metrics += [
    FScore(class_id=class_id, name=f"fscore_cls{class_id}") for class_id in range(class_nb)
]
model.compile(
    loss={tgt_key: keras.losses.CategoricalCrossentropy()},
    optimizer=keras.optimizers.Adam(params.learning_rate),
    metrics={tgt_key: metrics},
)
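To make the per-class metrics concrete, here is a minimal pure-Python sketch of what precision, recall and F1-score for a single class amount to, starting from true-positive, false-positive and false-negative counts. The helper per_class_scores and the toy label lists are hypothetical and only illustrate the definitions; they are not the Keras or mymetrics.py implementation, which work on one-hot targets and softmax outputs and accumulate counts over batches.

```python
def per_class_scores(y_true, y_pred, class_id):
    """Return (precision, recall, f1) for one class from two label lists."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == class_id and p == class_id)
    fp = sum(1 for t, p in pairs if t != class_id and p == class_id)
    fn = sum(1 for t, p in pairs if t == class_id and p != class_id)
    # Guard against empty denominators (class never predicted / never present)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy labels, for illustration only (hypothetical data)
y_true = [0, 1, 1, 2, 1, 0]
y_pred = [0, 1, 2, 2, 1, 1]
scores = {c: per_class_scores(y_true, y_pred, c) for c in range(3)}
for class_id, (precision, recall, f1) in scores.items():
    print(f"class {class_id}\tP={precision:.2f}\tR={recall:.2f}\tF1={f1:.2f}")
```

One value per class, as in the metrics list above, makes it easier to spot a class that is poorly detected even when the overall loss looks good.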
Callbacks¶
The different callbacks to be used during training.
ModelCheckpoint¶
We let Keras compute the metrics over the validation dataset after each epoch, and we save the model only when the validation loss reaches a new minimum.
callbacks = []
callbacks.append(
    keras.callbacks.ModelCheckpoint(
        params.model, mode="min", save_best_only=True, monitor="val_loss"
    )
)
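The selection rule behind save_best_only=True with mode="min" can be sketched in a few lines of pure Python. The function best_epochs and the validation loss values are hypothetical; the sketch only mirrors the "save when strictly lower than the best so far" logic, not the Keras implementation.

```python
def best_epochs(val_losses):
    """Return the 1-based epochs at which a 'min' checkpoint would be written."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strictly better than the best value seen so far
            best = loss
            saved.append(epoch)
    return saved


# Hypothetical validation losses over five epochs
saved = best_epochs([0.90, 0.70, 0.75, 0.60, 0.65])
print(saved)
```

Since every save goes to the same path (params.model), the file left on disk at the end of training corresponds to the last epoch in that list.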
TensorBoard¶
We keep a trace of the metrics for TensorBoard using the Keras built-in
keras.callbacks.TensorBoard class:
if params.log_dir:
    callbacks.append(keras.callbacks.TensorBoard(log_dir=params.log_dir))
This is optional: if we don't provide params.log_dir (set by the
--log_dir command line argument), the callback is not used.
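Why an omitted argument disables the callback can be checked with argparse alone: an option declared without a default is None when absent, and None is falsy. The sketch below is a minimal illustration; the helper wants_tensorboard is hypothetical, and the path is the one used in the example command of this section.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--log_dir", help="logs directory")  # no default: None if absent


def wants_tensorboard(params):
    """Hypothetical helper: True when a log directory was given."""
    return bool(params.log_dir)


without_logs = parser.parse_args([])
with_logs = parser.parse_args(["--log_dir", "/data/logs/model3"])
print(wants_tensorboard(without_logs), wants_tensorboard(with_logs))
```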
BackupAndRestore¶
Since training our model can take a while on CPU, we use a callback that regularly saves the model weights, so that training can be interrupted and resumed at any time (e.g. after a hardware failure or a killed process).
Info
The backup-and-restore callback deletes the checkpoint once the training is
completed, i.e. after the specified number of epochs in fit() has been
performed.
This is optional, and if we don't provide the params.ckpt_dir (from
the command line argument parser --ckpt_dir argument), it won't be
used.
Warning
When the --ckpt_dir argument is provided in the CLI, the training will
always use the weights saved in the provided directory. Do not forget to
delete the directory when you want to restart a training from scratch after
a previously interrupted training.
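The behaviour described above (resume at the right epoch, delete the backup on completion, reuse a leftover backup) can be illustrated with a toy loop in pure Python. This is only a sketch of the idea under simplifying assumptions: toy_fit, the state.json file and the stop_after parameter are hypothetical and unrelated to how Keras actually stores its backup.

```python
import json
import os
import tempfile


def toy_fit(epochs, backup_dir, stop_after=None):
    """Toy resumable loop; returns the list of epochs actually run."""
    state_file = os.path.join(backup_dir, "state.json")
    start = 0
    if os.path.exists(state_file):  # a previous run left a backup behind
        with open(state_file) as f:
            start = json.load(f)["next_epoch"]
    done = []
    for epoch in range(start, epochs):
        done.append(epoch)  # stand-in for one real training epoch
        with open(state_file, "w") as f:
            json.dump({"next_epoch": epoch + 1}, f)
        if stop_after is not None and len(done) == stop_after:
            return done  # simulate an interruption (e.g. Ctrl+C)
    if os.path.exists(state_file):
        os.remove(state_file)  # training completed: the backup is deleted
    return done


with tempfile.TemporaryDirectory() as backup_dir:
    first_run = toy_fit(epochs=5, backup_dir=backup_dir, stop_after=2)
    second_run = toy_fit(epochs=5, backup_dir=backup_dir)
    leftover = os.listdir(backup_dir)
print(first_run, second_run, leftover)
```

The second call picks up where the first one stopped because the backup is still there, which is also why a stale directory must be deleted to restart from scratch.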
Question
- Implement the backup-and-restore callback,
- Check that the callback works: interrupt the training process with the
Ctrl+C keys, then check that the training restarts at the right epoch
with the same metrics.
Final evaluation¶
Finally, once the model is properly trained (i.e. saved each time the monitored metric has reached a new minimum), we compute the metrics over the test dataset, so that we have a proper evaluation of its performance on an independent dataset.
# Final evaluation on the test dataset
model.load_weights(params.model)
values = model.evaluate(ds_test, batch_size=params.batch_size, return_dict=True)
print("\nMetrics over test dataset:")
for name, value in values.items():
    print(f"{name}\t{value}")
Question
- Append this code after model.fit(...), at the end of the training.
You can now run the training using --ckpt_dir to select a directory
to write the checkpoints. Also you can change the --log_dir to select
another directory for the TensorBoard logs. Use --epochs to select the
number of epochs.
python part_3_train.py \
--model /data/models/model3.keras \
--log_dir /data/logs/model3 \
--epochs 50 \
--ckpt_dir /data/ckpts/model3
Question
- Compare the metrics over the validation dataset, the test dataset and the training dataset.
Solution
"""Semantic segmentation of Spot-7 images"""
from mymetrics import FScore
import otbtf
import keras
import argparse
class_nb = 4 # number of classes
inp_key_p = "input_p" # model input p
inp_key_xs = "input_xs" # model input xs
tgt_key = "estimated" # model target

def create_otbtf_dataset(p, xs, labels):
    return otbtf.DatasetFromPatchesImages(filenames_dict={"p": p, "xs": xs, "labels": labels})

def dataset_preprocessing_fn(sample):
    return {
        inp_key_p: sample["p"],
        inp_key_xs: sample["xs"],
        tgt_key: otbtf.ops.one_hot(labels=sample["labels"], nb_classes=class_nb),
    }

def create_dataset(p, xs, labels, batch_size=8):
    otbtf_dataset = create_otbtf_dataset(p, xs, labels)
    return otbtf_dataset.get_tf_dataset(
        batch_size=batch_size,
        preprocessing_fn=dataset_preprocessing_fn,
        targets_keys=[tgt_key],
    )

def conv(inp, depth, name, strides=2):
    conv_op = keras.layers.Conv2D(
        filters=depth,
        kernel_size=3,
        strides=strides,
        activation="relu",
        padding="same",
        name=name,
    )
    return conv_op(inp)

def tconv(inp, depth, name, activation="relu"):
    tconv_op = keras.layers.Conv2DTranspose(
        filters=depth,
        kernel_size=3,
        strides=2,
        activation=activation,
        padding="same",
        name=name,
    )
    return tconv_op(inp)

class FCNNModel(otbtf.ModelBase):
    def normalize_inputs(self, inputs):
        return {
            inp_key_p: keras.ops.cast(inputs[inp_key_p], "float32") * 0.01,
            inp_key_xs: keras.ops.cast(inputs[inp_key_xs], "float32") * 0.01,
        }

    def get_outputs(self, normalized_inputs):
        norm_inp_xs = normalized_inputs[inp_key_xs]
        cv_xs = conv(norm_inp_xs, 32, "conv_xs", 1)
        norm_inp_p = normalized_inputs[inp_key_p]
        cv1 = conv(norm_inp_p, 16, "conv1")
        cv2 = conv(cv1, 32, "conv2") + cv_xs
        cv3 = conv(cv2, 64, "conv3")
        cv4 = conv(cv3, 64, "conv4")
        cv1t = tconv(cv4, 64, "conv1t") + cv3
        cv2t = tconv(cv1t, 32, "conv2t") + cv2
        cv3t = tconv(cv2t, 16, "conv3t") + cv1
        cv4t = tconv(cv3t, class_nb, "softmax_layer", activation="softmax")
        argmax_op = otbtf.layers.Argmax()
        return {tgt_key: cv4t, "estimated_labels": argmax_op(cv4t)}

def train(params, ds_train, ds_valid, ds_test):
    model = FCNNModel(dataset_element_spec=ds_train.element_spec)
    # Precision and recall for each class
    metrics = [
        cls(class_id=class_id)
        for class_id in range(class_nb)
        for cls in [keras.metrics.Precision, keras.metrics.Recall]
    ]
    # F1-Score for each class
    metrics += [
        FScore(class_id=class_id, name=f"fscore_cls{class_id}") for class_id in range(class_nb)
    ]
    model.compile(
        loss={tgt_key: keras.losses.CategoricalCrossentropy()},
        optimizer=keras.optimizers.Adam(params.learning_rate),
        metrics={tgt_key: metrics},
    )
    model.summary()
    callbacks = []
    callbacks.append(
        keras.callbacks.ModelCheckpoint(
            params.model, mode="min", save_best_only=True, monitor="val_loss"
        )
    )
    if params.log_dir:
        callbacks.append(keras.callbacks.TensorBoard(log_dir=params.log_dir))
    if params.ckpt_dir:
        callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=params.ckpt_dir))
    # Train the model
    model.fit(
        ds_train,
        epochs=params.epochs,
        validation_data=ds_valid,
        callbacks=callbacks,
    )
    # Final evaluation on the test dataset
    model.load_weights(params.model)
    values = model.evaluate(ds_test, batch_size=params.batch_size, return_dict=True)
    print("\nMetrics over test dataset:")
    for name, value in values.items():
        print(f"{name}\t{value}")

parser = argparse.ArgumentParser(description="Train a FCNN model")
parser.add_argument("--model", required=True, help="model path (.keras file)")
parser.add_argument("--log_dir", help="logs directory")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=0.0002)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--ckpt_dir", help="Directory for checkpoints")
params = parser.parse_args()
ds_train = create_dataset(
    ["/data/train_p_patches.tif"],
    ["/data/train_xs_patches.tif"],
    ["/data/train_labels_patches.tif"],
)
ds_train = ds_train.shuffle(buffer_size=100)
ds_valid = create_dataset(
    ["/data/valid_p_patches.tif"],
    ["/data/valid_xs_patches.tif"],
    ["/data/valid_labels_patches.tif"],
)
ds_test = create_dataset(
    ["/data/test_p_patches.tif"],
    ["/data/test_xs_patches.tif"],
    ["/data/test_labels_patches.tif"],
)
train(params, ds_train, ds_valid, ds_test)
Dig deeper 🚀¶
Image summary¶
Use TensorFlow summaries to monitor additional data during the training process.
Question
- Create a summary writer in the same directory as the TensorBoard logs,
- Add a new summary showing the softmax output image for the class id 1
Custom callback¶
One can implement a custom callback using the keras callback API.
Question
- Implement a custom callback that wraps the softmax output image summary.
The callback must generate the image summary after each epoch (override
on_epoch_end()).
- Test your callback and check that the image summaries are displayed in TensorBoard.