# -*- coding: utf-8 -*-
# ==============================================================================
# UNet Image Segmentation Notebook (Final Version with Custom Weighted Loss)
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight # Still used for calculating weights initially
import matplotlib.pyplot as plt
import cv2 # Required for display_individual_class_segmentation overlay
# ==============================================================================
# STEP 1: Mount Google Drive
#from google.colab import drive
#drive.mount('/content/drive')
# ==============================================================================
# STEP 2: Define Data Paths and Parameters
DATA_PATH = './data'
IMAGE_DIR = os.path.join(DATA_PATH, 'images')
MASK_DIR = os.path.join(DATA_PATH, 'masks')
IMG_SIZE = (256, 256)
NUM_CLASSES = 4
PIXEL_VALUE_MAP = {0: 0, 85: 1, 170: 2, 255: 3}
REVERSE_VALUE_MAP = {0: 0, 1: 85, 2: 170, 3: 255}
# ==============================================================================
# STEP 3: Load and Preprocess the Dataset
def load_data(image_dir, mask_dir, img_size, pixel_map):
    """
    Loads images and masks from directories, resizes them, and remaps the
    mask gray levels to contiguous integer class labels.
    Masks are returned as (H, W) integer label maps.
    """
    image_filenames = sorted(os.listdir(image_dir))
    mask_filenames = sorted(os.listdir(mask_dir))
    images = []
    masks_int_labels = []  # Masks stored as (H, W) integer label maps
    assert len(image_filenames) == len(mask_filenames), "Number of images and masks do not match!"
    print(f"Loading {len(image_filenames)} images...")
    for i in range(len(image_filenames)):
        image_path = os.path.join(image_dir, image_filenames[i])
        mask_path = os.path.join(mask_dir, mask_filenames[i])
        img = load_img(image_path, target_size=img_size)
        img_array = img_to_array(img) / 255.0
        images.append(img_array)
        # load_img resizes with nearest-neighbor interpolation by default,
        # so the discrete gray levels of the mask are preserved.
        mask = load_img(mask_path, target_size=img_size, color_mode='grayscale')
        mask_array = img_to_array(mask).astype(np.int32).squeeze()  # (H, W)
        # Remap by reading from the original array and writing into a copy,
        # so a freshly assigned label can never be matched again by a later
        # (old_val -> new_val) pair, regardless of the map's contents.
        remapped = mask_array.copy()
        for old_val, new_val in pixel_map.items():
            remapped[mask_array == old_val] = new_val
        masks_int_labels.append(remapped)
    return np.array(images), np.array(masks_int_labels)
# Load the data
images, masks_int = load_data(IMAGE_DIR, MASK_DIR, IMG_SIZE, PIXEL_VALUE_MAP)
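# Optional sanity check (a minimal sketch, assuming the masks contain only the
# four expected gray levels): after remapping, every label should lie in
# [0, NUM_CLASSES). Any leftover raw gray value would trip this assertion.
assert set(np.unique(masks_int)) <= set(range(NUM_CLASSES)), \
    "Found unexpected pixel values in the masks - check PIXEL_VALUE_MAP."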
# Split the dataset into training and validation sets
# masks_int is (N, H, W) integer labels
X_train, X_val, y_train_int, y_val_int = train_test_split(
    images, masks_int, test_size=0.2, random_state=42
)
print(f"Training images shape: {X_train.shape}")
print(f"Validation masks shape: {y_val_int.shape}") # This is (N, H, W)
# ==============================================================================
# STEP 4: Calculate Class Weights
# This step computes weights inversely proportional to class frequencies
# to be used in the custom loss function.
# ==============================================================================
print("\nCalculating class weights...")
# y_train_int is already in (N, H, W) integer format
y_train_indices_flat = y_train_int.flatten()
all_possible_classes = np.arange(NUM_CLASSES)
class_weights_array = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=all_possible_classes,
    y=y_train_indices_flat
)
# Convert the array to a TensorFlow tensor for use in the custom loss
# Ensure weights are float32
CLASS_WEIGHTS_TENSOR = tf.constant(class_weights_array, dtype=tf.float32)
print(f"Calculated Class Weights (as array): {class_weights_array}")
print(f"Class Weights (as TensorFlow tensor): {CLASS_WEIGHTS_TENSOR}")
# ==============================================================================
# Define Custom Weighted Categorical Crossentropy Loss
# This replaces the simple 'categorical_crossentropy' loss.
# ==============================================================================
def weighted_categorical_crossentropy_loss(y_true, y_pred):
    # y_true: one-hot encoded ground truth, shape (Batch, H, W, NUM_CLASSES)
    # y_pred: softmax probabilities, shape (Batch, H, W, NUM_CLASSES)
    # Recover sparse integer labels (Batch, H, W) from the one-hot encoding
    y_true_labels = tf.argmax(y_true, axis=-1)
    # Flatten the labels to a 1D tensor of class indices, shape (Batch*H*W,)
    y_true_labels_flat = tf.reshape(y_true_labels, [-1])
    # Flatten y_pred to (Batch*H*W, NUM_CLASSES) for the per-pixel CCE
    y_pred_flat = tf.reshape(y_pred, [-1, NUM_CLASSES])
    # Look up the weight for each pixel from its true class
    pixel_weights = tf.gather(CLASS_WEIGHTS_TENSOR, y_true_labels_flat)
    # Standard categorical crossentropy, kept per pixel (no reduction).
    # The one-hot y_true is flattened to match y_pred_flat.
    y_true_flat_one_hot = tf.reshape(y_true, [-1, NUM_CLASSES])
    cce = tf.keras.losses.CategoricalCrossentropy(
        from_logits=False,  # y_pred is probabilities (softmax output)
        reduction=tf.keras.losses.Reduction.NONE  # keep loss per pixel
    )
    unweighted_loss = cce(y_true_flat_one_hot, y_pred_flat)
    # Apply the pixel-wise weights
    weighted_loss = unweighted_loss * pixel_weights
    # Return the mean weighted loss over all pixels in the batch
    return tf.reduce_mean(weighted_loss)
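# Smoke test for the custom loss (a minimal sketch on a hypothetical dummy
# batch): feeding a perfect prediction should yield a loss near zero, so a
# large value here would signal a shape or weighting bug.
_dummy_true = tf.one_hot(tf.zeros((2, 4, 4), dtype=tf.int32), depth=NUM_CLASSES)
_dummy_loss = weighted_categorical_crossentropy_loss(_dummy_true, _dummy_true)
print(f"Smoke-test loss on a perfect dummy prediction: {float(_dummy_loss):.6f}")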
# ==============================================================================
# Convert masks to one-hot for model training (after class weights are calculated)
# The model's loss function (our custom one) expects one-hot encoded masks.
# ==============================================================================
print("\nConverting masks to one-hot encoding for model training...")
y_train_one_hot = tf.one_hot(y_train_int, depth=NUM_CLASSES)
y_val_one_hot = tf.one_hot(y_val_int, depth=NUM_CLASSES)
print(f"One-hot encoded training masks shape: {y_train_one_hot.shape}")
print(f"One-hot encoded validation masks shape: {y_val_one_hot.shape}")
# ==============================================================================
# STEP 5: Build the UNet Model
# ==============================================================================
def unet_model(input_size=(256, 256, 3), num_classes=NUM_CLASSES):
    """
    Defines the UNet model architecture.
    """
    inputs = layers.Input(input_size)
    # Encoder
    conv1 = layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = layers.Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(128, 3, activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(128, 3, activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = layers.Conv2D(256, 3, activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(256, 3, activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = layers.Conv2D(512, 3, activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(512, 3, activation='relu', padding='same')(conv4)
    drop4 = layers.Dropout(0.5)(conv4)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(drop4)
    # Bottleneck
    conv5 = layers.Conv2D(1024, 3, activation='relu', padding='same')(pool4)
    conv5 = layers.Conv2D(1024, 3, activation='relu', padding='same')(conv5)
    drop5 = layers.Dropout(0.5)(conv5)
    # Decoder: upsample, concatenate with the matching encoder skip, convolve
    up6 = layers.Conv2D(512, 2, activation='relu', padding='same')(layers.UpSampling2D(size=(2, 2))(drop5))
    merge6 = layers.concatenate([drop4, up6], axis=3)
    conv6 = layers.Conv2D(512, 3, activation='relu', padding='same')(merge6)
    conv6 = layers.Conv2D(512, 3, activation='relu', padding='same')(conv6)
    up7 = layers.Conv2D(256, 2, activation='relu', padding='same')(layers.UpSampling2D(size=(2, 2))(conv6))
    merge7 = layers.concatenate([conv3, up7], axis=3)
    conv7 = layers.Conv2D(256, 3, activation='relu', padding='same')(merge7)
    conv7 = layers.Conv2D(256, 3, activation='relu', padding='same')(conv7)
    up8 = layers.Conv2D(128, 2, activation='relu', padding='same')(layers.UpSampling2D(size=(2, 2))(conv7))
    merge8 = layers.concatenate([conv2, up8], axis=3)
    conv8 = layers.Conv2D(128, 3, activation='relu', padding='same')(merge8)
    conv8 = layers.Conv2D(128, 3, activation='relu', padding='same')(conv8)
    up9 = layers.Conv2D(64, 2, activation='relu', padding='same')(layers.UpSampling2D(size=(2, 2))(conv8))
    merge9 = layers.concatenate([conv1, up9], axis=3)
    conv9 = layers.Conv2D(64, 3, activation='relu', padding='same')(merge9)
    conv9 = layers.Conv2D(64, 3, activation='relu', padding='same')(conv9)
    # 1x1 convolution with softmax yields per-pixel class probabilities
    conv10 = layers.Conv2D(num_classes, 1, activation='softmax')(conv9)
    model = Model(inputs=inputs, outputs=conv10)
    return model
model = unet_model(input_size=(*IMG_SIZE, 3))
model.summary()
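# Quick shape check (optional): the softmax head should emit one probability
# map per class at the input resolution.
assert model.output_shape == (None, *IMG_SIZE, NUM_CLASSES), \
    "Unexpected model output shape."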
# ==============================================================================
# STEP 6: Compile and Train the Model (with Custom Weighted Loss)
# The model is compiled with an optimizer, our custom weighted loss, and metrics.
# ==============================================================================
model.compile(
    optimizer='adam',
    loss=weighted_categorical_crossentropy_loss,  # custom weighted loss
    metrics=['accuracy']
)
print("\nTraining the model with custom weighted loss...")
history = model.fit(
    X_train, y_train_one_hot,  # one-hot encoded targets
    validation_data=(X_val, y_val_one_hot),
    epochs=20,  # consider increasing this if the model is still improving
    batch_size=4,
    # No class_weight argument: the weighting is handled inside the loss.
)
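# Optional: visualize the training curves from the History object (a minimal
# sketch, assuming the default 'accuracy'/'val_accuracy' metric names produced
# by the compile() call above).
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.title('Weighted CCE Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='train acc')
plt.plot(history.history['val_accuracy'], label='val acc')
plt.title('Pixel Accuracy')
plt.legend()
plt.tight_layout()
plt.show()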
# ==============================================================================
# STEP 7: Save the Model
# ==============================================================================
MODEL_SAVE_DIR = os.path.join(DATA_PATH, 'models')
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
model_filename = 'unet_segmentation_model_custom_weighted_loss.keras' # New filename
model_save_path = os.path.join(MODEL_SAVE_DIR, model_filename)
print(f"\nSaving model to: {model_save_path}")
try:
    model.save(model_save_path)
    print("Model saved successfully!")
except Exception as e:
    print(f"Error saving model: {e}")
history_filename = 'training_history_custom_weighted_loss.npy'
history_save_path = os.path.join(MODEL_SAVE_DIR, history_filename)
np.save(history_save_path, history.history)
print(f"Training history saved to: {history_save_path}")
# ==============================================================================
# STEP 8: Visualize the Results (Overall and Per-Class)
# ==============================================================================
def display_results(image, true_mask, pred_mask, reverse_map):
    """
    Displays the original image, ground truth mask, and the overall predicted mask.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    true_mask_indices = np.argmax(true_mask, axis=-1)
    pred_mask_indices = np.argmax(pred_mask, axis=-1)
    true_mask_vis = np.vectorize(reverse_map.get)(true_mask_indices)
    pred_mask_vis = np.vectorize(reverse_map.get)(pred_mask_indices)
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    axes[1].imshow(true_mask_vis, cmap='gray', vmin=0, vmax=255)
    axes[1].set_title('Ground Truth Mask')
    axes[1].axis('off')
    axes[2].imshow(pred_mask_vis, cmap='gray', vmin=0, vmax=255)
    axes[2].set_title('Predicted Mask (All Classes)')
    axes[2].axis('off')
    plt.tight_layout()
    plt.show()
def display_individual_class_segmentation(image, pred_mask, pixel_value_map, class_names=None):
    """
    Displays the segmentation for each individual class on separate images,
    overlaying the predicted class onto the original image.
    """
    pred_mask_indices = np.argmax(pred_mask, axis=-1)
    if class_names is None:
        sorted_pixel_values = sorted(pixel_value_map.items(), key=lambda item: item[1])
        class_names = [f"Class {index} ({original_val})" for original_val, index in sorted_pixel_values]
    num_classes = pred_mask.shape[-1]
    fig, axes = plt.subplots(1, num_classes, figsize=(num_classes * 5, 5))
    if num_classes == 1:
        axes = [axes]
    # Overlay colors in BGR order (OpenCV convention)
    color_map_bgr = {
        0: [0, 0, 0],      # Black for Background
        1: [0, 255, 0],    # Green for Sample (85)
        2: [0, 0, 255],    # Red for Joint (170)
        3: [255, 0, 0]     # Blue for Strata (255)
    }
    # Convert the normalized RGB image to 8-bit BGR once, outside the loop
    image_255 = (image * 255).astype(np.uint8)
    image_255_bgr = cv2.cvtColor(image_255, cv2.COLOR_RGB2BGR)
    for i in range(num_classes):
        class_specific_mask = (pred_mask_indices == i).astype(np.uint8)
        color = color_map_bgr.get(i, [255, 255, 255])
        # Build a solid-color overlay restricted to the predicted class
        overlay_bgr = np.zeros_like(image_255_bgr)
        for c_chan in range(3):
            overlay_bgr[:, :, c_chan] = class_specific_mask * color[c_chan]
        # Alpha-blend the overlay onto the original image
        alpha = 0.5
        blended_image_bgr = cv2.addWeighted(image_255_bgr, 1 - alpha, overlay_bgr, alpha, 0)
        blended_image_rgb = cv2.cvtColor(blended_image_bgr, cv2.COLOR_BGR2RGB)
        axes[i].imshow(blended_image_rgb)
        axes[i].set_title(f'Predicted: {class_names[i]}')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
sample_index = np.random.randint(0, len(X_val))
sample_image = X_val[sample_index]
sample_true_mask = y_val_one_hot[sample_index] # Use one-hot for true mask display
sample_pred_mask = model.predict(np.expand_dims(sample_image, 0))[0]
display_results(sample_image, sample_true_mask, sample_pred_mask, REVERSE_VALUE_MAP)
custom_class_names = ["Background (0)", "Sample (85)", "Joint (170)", "Strata (255)"]
display_individual_class_segmentation(sample_image, sample_pred_mask, PIXEL_VALUE_MAP, custom_class_names)
# OPTIONAL: Debugging Check
print("\n--- Debugging Predicted Mask (After Training with Custom Weighted Loss) ---")
print("Shape of sample_pred_mask (raw probabilities):", sample_pred_mask.shape)
print("Min/Max/Mean probabilities for each class channel:")
for i in range(NUM_CLASSES):
    channel_probs = sample_pred_mask[:, :, i]
    print(f"  Class {i}: Min={np.min(channel_probs):.4f}, Max={np.max(channel_probs):.4f}, Mean={np.mean(channel_probs):.4f}")
pred_mask_indices = np.argmax(sample_pred_mask, axis=-1)
print("Shape of pred_mask_indices (after argmax):", pred_mask_indices.shape)
print("Unique values in pred_mask_indices:", np.unique(pred_mask_indices))
unique_classes, counts = np.unique(pred_mask_indices, return_counts=True)
total_pixels = pred_mask_indices.size
print("Distribution of predicted classes:")
for u_class, count in zip(unique_classes, counts):
    print(f"  Class {u_class}: {count} pixels ({count / total_pixels * 100:.2f}%)")
pred_mask_vis = np.vectorize(REVERSE_VALUE_MAP.get)(pred_mask_indices)
print("Unique values in pred_mask_vis (after reverse map):", np.unique(pred_mask_vis))
print("--- End Debugging Predicted Mask ---")