"""
train_deeplabv3p.py
Requirements:
- tensorflow >= 2.8 (tested on TF 2.10+)
- matplotlib
- opencv-python (cv2) optional if you want to preview images locally
Dataset layout expected:
fin/images/<name>.png (RGB)
fin/mask/<name>.png (grayscale, 0 for background, 127 for class)
Usage:
python train_deeplabv3p.py
"""
import os
import random
import glob
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# ----------------------
# Config - tweak here
# ----------------------
IM_SIZE = (256, 256) # input size for training (height, width)
BATCH_SIZE = 8
EPOCHS = 30
AUTOTUNE = tf.data.AUTOTUNE
DATA_DIR = "./dataset"
IMAGE_DIR = os.path.join(DATA_DIR, "images")
MASK_DIR = os.path.join(DATA_DIR, "masks")
MODEL_SAVE = "coredrill_deeplabv3p.h5"
VAL_SPLIT = 0.15
SEED = 42
LEARNING_RATE = 1e-4
# ----------------------
class BinaryMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that thresholds sigmoid probabilities before scoring."""
    def __init__(self, num_classes=2, name="iou", dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.cast(y_pred > 0.5, tf.int32)
        y_true = tf.cast(y_true, tf.int32)
        return super().update_state(y_true, y_pred, sample_weight)
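# Why the subclass: keras.metrics.MeanIoU expects integer class ids, while the
# model emits sigmoid probabilities, so we threshold at 0.5 first. A minimal
# illustration (values are hypothetical):
#   m = BinaryMeanIoU()
#   m.update_state([[1.0, 0.0]], [[0.9, 0.2]])
#   float(m.result())  # 1.0 -- both pixels land in the right class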
# ----------------------
# Utility: DeepLabV3+ model (MobileNetV2 backbone)
# ----------------------
def SepConv_BN(x, filters, prefix, stride=1, kernel_size=3, rate=1):
x = layers.SeparableConv2D(filters, kernel_size=kernel_size, strides=stride,
padding='same', dilation_rate=rate,
use_bias=False, name=prefix + '_sepconv')(x)
x = layers.BatchNormalization(name=prefix + '_bn')(x)
x = layers.ReLU(name=prefix + '_relu')(x)
return x
def ASPP(x, out_channels=256):
    # Atrous Spatial Pyramid Pooling: a 1x1 branch, dilated 3x3 branches
    # (rates 6, 12, 18), and a global image-pooling branch
b0 = layers.Conv2D(out_channels, 1, padding='same', use_bias=False)(x)
b0 = layers.BatchNormalization()(b0)
b0 = layers.ReLU()(b0)
b1 = layers.SeparableConv2D(out_channels, 3, padding='same', dilation_rate=6, use_bias=False)(x)
b1 = layers.BatchNormalization()(b1)
b1 = layers.ReLU()(b1)
b2 = layers.SeparableConv2D(out_channels, 3, padding='same', dilation_rate=12, use_bias=False)(x)
b2 = layers.BatchNormalization()(b2)
b2 = layers.ReLU()(b2)
b3 = layers.SeparableConv2D(out_channels, 3, padding='same', dilation_rate=18, use_bias=False)(x)
b3 = layers.BatchNormalization()(b3)
b3 = layers.ReLU()(b3)
# Image pooling branch
b4 = layers.GlobalAveragePooling2D()(x)
b4 = layers.Reshape((1, 1, -1))(b4)
b4 = layers.Conv2D(out_channels, 1, padding='same', use_bias=False)(b4)
b4 = layers.BatchNormalization()(b4)
b4 = layers.ReLU()(b4)
    # The pooled branch is 1x1, so upsampling by the feature map's static
    # spatial size restores it to (H, W) to match the other branches
    b4 = layers.UpSampling2D(size=(x.shape[1], x.shape[2]), interpolation='bilinear')(b4)
# Concatenate and project
x = layers.Concatenate()([b0, b1, b2, b3, b4])
x = layers.Conv2D(out_channels, 1, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
return x
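# Shape sanity check for ASPP (a sketch; 576 channels assumes MobileNetV2's
# block_13_expand_relu feature map, 16x16 for a 256x256 input):
#   feats = tf.keras.Input(shape=(16, 16, 576))
#   tf.keras.Model(feats, ASPP(feats)).output_shape  # (None, 16, 16, 256)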
def DeepLabV3Plus(input_shape=(256, 256, 3), num_classes=1):
# Encoder (MobileNetV2)
base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
    # Extract two feature maps from the encoder:
    # low-level features for the decoder, stride 4 (64x64 for a 256x256 input)
    low_level = base_model.get_layer('block_3_expand_relu').output
    # high-level features for ASPP, stride 16 (16x16 for a 256x256 input)
    high_level = base_model.get_layer('block_13_expand_relu').output
# ASPP on high-level features
x = ASPP(high_level, out_channels=256)
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)  # stride 16 -> stride 4, matching the low-level features
# Process low-level
low = layers.Conv2D(48, 1, padding='same', use_bias=False)(low_level)
low = layers.BatchNormalization()(low)
low = layers.ReLU()(low)
# Concatenate
x = layers.Concatenate()([x, low])
x = SepConv_BN(x, 256, 'decoder_separable_conv0')
x = SepConv_BN(x, 256, 'decoder_separable_conv1')
# Upsample to input size
x = layers.UpSampling2D(size=(4,4), interpolation='bilinear')(x)
# Final conv
if num_classes == 1:
activation = 'sigmoid'
out_filters = 1
else:
activation = 'softmax'
out_filters = num_classes
x = layers.Conv2D(out_filters, 1, padding='same')(x)
x = layers.Activation(activation)(x)
model = tf.keras.Model(inputs=base_model.input, outputs=x)
return model
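# Resolution bookkeeping for the default 256x256 input: stride-16 encoder
# features (16x16) -> ASPP -> x4 upsample (64x64, matching the stride-4 skip)
# -> decoder -> x4 upsample (256x256). A quick check (downloads ImageNet
# weights on first call):
#   DeepLabV3Plus().output_shape  # (None, 256, 256, 1)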
# ----------------------
# Data pipeline
# ----------------------
def list_pairs(image_dir, mask_dir):
# match by filename (without extension)
images = sorted(glob.glob(os.path.join(image_dir, "*")))
image_map = {os.path.splitext(os.path.basename(p))[0]: p for p in images}
masks = sorted(glob.glob(os.path.join(mask_dir, "*")))
mask_map = {os.path.splitext(os.path.basename(p))[0]: p for p in masks}
common = sorted(set(image_map.keys()).intersection(mask_map.keys()))
pairs = [(image_map[k], mask_map[k]) for k in common]
return pairs
def decode_image(path, target_size=IM_SIZE):
    img = tf.io.read_file(path)
    # expand_animations=False guarantees a 3-D tensor (decode_image would
    # otherwise return 4-D for animated GIFs, breaking set_shape/resize)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, target_size)
    img = tf.cast(img, tf.float32) / 255.0
    return img
def decode_mask(path, target_size=IM_SIZE):
    m = tf.io.read_file(path)
    m = tf.image.decode_image(m, channels=1, expand_animations=False)
    m.set_shape([None, None, 1])
    # Nearest-neighbour resizing keeps labels crisp (no interpolated grey values)
    m = tf.image.resize(m, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    m = tf.cast(m, tf.float32)
    # Binarize: works for both 0/127 and 0/255 style masks
    # (any value > 64 becomes 1, otherwise 0)
    m = tf.where(m > 64.0, 1.0, 0.0)
    return m
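# Threshold behaviour, e.g. tf.where(tf.constant([0., 127., 255.]) > 64.0, 1.0, 0.0)
# yields [0., 1., 1.], covering both the 0/127 and 0/255 mask conventions.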
def load_pair(image_path, mask_path):
image = decode_image(image_path)
mask = decode_mask(mask_path)
return image, mask
def augment(image, mask):
    # Simple augmentation: random flips (applied to image and mask together)
    # and image-only random brightness. AutoGraph converts these Python `if`s
    # on tensors into tf.cond when traced inside dataset.map.
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.random_brightness(image, max_delta=0.1)
        image = tf.clip_by_value(image, 0.0, 1.0)  # keep inputs in [0, 1]
    return image, mask
def make_datasets(pairs, batch_size=BATCH_SIZE, val_split=VAL_SPLIT):
    random.seed(SEED)
    random.shuffle(pairs)
    n = len(pairs)
    n_val = max(1, int(n * val_split))
    val_pairs = pairs[:n_val]
    train_pairs = pairs[n_val:]
    def to_dataset(pairs_list):
        img_paths, mask_paths = zip(*pairs_list)
        return tf.data.Dataset.from_tensor_slices((list(img_paths), list(mask_paths)))
    # decode_image/decode_mask are pure TF ops, so they can be mapped directly;
    # no tf.py_function wrapper is needed, and static shapes are preserved,
    # which also makes the ensure_shape step redundant.
    train_ds = (to_dataset(train_pairs)
                .map(load_pair, num_parallel_calls=AUTOTUNE)
                .map(augment, num_parallel_calls=AUTOTUNE)
                .shuffle(256)
                .batch(batch_size)
                .prefetch(AUTOTUNE)
                )
    val_ds = (to_dataset(val_pairs)
              .map(load_pair, num_parallel_calls=AUTOTUNE)
              .batch(batch_size)
              .prefetch(AUTOTUNE)
              )
    return train_ds, val_ds, train_pairs, val_pairs
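# Pipeline smoke test (a sketch, assuming the default config above and at
# least one full batch of training pairs):
#   train_ds, val_ds, *_ = make_datasets(list_pairs(IMAGE_DIR, MASK_DIR))
#   imgs, masks = next(iter(train_ds))
#   imgs.shape, masks.shape  # (8, 256, 256, 3), (8, 256, 256, 1)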
# ----------------------
# Metrics and Loss
# ----------------------
def dice_coef(y_true, y_pred, smooth=1e-6):
y_true_f = tf.reshape(y_true, [-1])
y_pred_f = tf.reshape(y_pred, [-1])
intersection = tf.reduce_sum(y_true_f * y_pred_f)
return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
def dice_loss(y_true, y_pred):
return 1.0 - dice_coef(y_true, y_pred)
def bce_dice_loss(y_true, y_pred):
bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
return bce + dice_loss(y_true, y_pred)
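# Worked example of the Dice term: with y_true = [1, 1, 0, 0] and a hard
# prediction y_pred = [1, 0, 0, 0], intersection = 1 and the sums are 2 + 1,
# so dice_coef ~= 2*1/3 ~= 0.667 and dice_loss ~= 0.333. BCE rewards per-pixel
# confidence while Dice counteracts the class imbalance of masks where
# background dominates, hence the sum of both.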
# ----------------------
# Training routine
# ----------------------
def main():
pairs = list_pairs(IMAGE_DIR, MASK_DIR)
if len(pairs) == 0:
raise RuntimeError(f"No matching image/mask pairs found in {IMAGE_DIR} and {MASK_DIR}.")
print(f"Found {len(pairs)} pairs.")
train_ds, val_ds, train_pairs, val_pairs = make_datasets(pairs)
model = DeepLabV3Plus(input_shape=(*IM_SIZE, 3), num_classes=1)
model.summary()
    # Compile; plain MeanIoU would score raw sigmoid probabilities, hence the
    # thresholding BinaryMeanIoU subclass defined above
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss=bce_dice_loss,
                  metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy'),
                           BinaryMeanIoU(),
                           dice_coef])
# Callbacks
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(MODEL_SAVE, save_best_only=True, monitor='val_loss')
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, verbose=1)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
history = model.fit(train_ds,
epochs=EPOCHS,
validation_data=val_ds,
callbacks=[checkpoint_cb, reduce_lr, early_stop])
    # Save the final model. Note this overwrites the best checkpoint written to
    # the same path; harmless when EarlyStopping has restored the best weights.
    model.save(MODEL_SAVE)
    print(f"Model saved to {MODEL_SAVE}")
# Visual comparison on a few validation samples
visualize_predictions(model, val_pairs, n=6)
def visualize_predictions(model, val_pairs, n=6):
    # Pick up to n validation examples at random
    samples = random.sample(val_pairs, min(n, len(val_pairs)))
    fig_rows = len(samples)
    plt.figure(figsize=(10, 4 * fig_rows))
    for i, (img_p, mask_p) in enumerate(samples):
        # Reuse the training-time decoders so preprocessing stays consistent
        img = decode_image(img_p)
        mask = decode_mask(mask_p).numpy().astype(np.uint8).squeeze()
        # Predict and threshold the sigmoid output
        inp = tf.expand_dims(img, 0)
        pred = model.predict(inp, verbose=0)[0]
        pred_mask = (pred[..., 0] > 0.5).astype(np.uint8)
        plt.subplot(fig_rows, 3, i * 3 + 1)
        plt.imshow(img.numpy())
        plt.title("Image")
        plt.axis('off')
        plt.subplot(fig_rows, 3, i * 3 + 2)
        plt.imshow(mask, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')
        plt.subplot(fig_rows, 3, i * 3 + 3)
plt.imshow(pred_mask, cmap='gray')
plt.title("Prediction")
plt.axis('off')
plt.tight_layout()
plt.show()
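# Reloading the saved H5 model later needs the custom objects, e.g.:
#   model = tf.keras.models.load_model(
#       MODEL_SAVE,
#       custom_objects={"bce_dice_loss": bce_dice_loss,
#                       "dice_coef": dice_coef,
#                       "BinaryMeanIoU": BinaryMeanIoU})
#   # or pass compile=False if you only need inference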
if __name__ == "__main__":
main()