mercoledì 6 agosto 2025

Swin-Unet su carote di sondaggio

 Leggendo, ho scoperto che il più recente sviluppo nella segmentazione di immagini è costituito dalla rete Swin-Unet (nata dalla fusione di Swin Transformer e U-Net).


Per far funzionare la rete, per prima cosa si deve clonare all'interno della cartella del progetto il repository https://github.com/HuCaoFighting/Swin-Unet (lo script lo cerca nella sottocartella ./Swin-Unet, che aggiunge a sys.path).
Il tempo di calcolo è nettamente inferiore rispetto a quello di DeepLab.


import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

import sys
sys.path.append('./Swin-Unet')

from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
import timm

# ======================= Configuration =======================
# Input locations. Images and masks are paired by identical file name.
IMAGE_DIR = "data/images"
MASK_DIR = "data/masks"

# Model / training hyper-parameters.
NUM_CLASSES = 4  # background, sample, joint, strata
EPOCHS = 30
BATCH_SIZE = 8
# NOTE(review): 224 is the input size passed to SwinTransformerSys below;
# other sizes must be compatible with its patch/window sizes -- confirm in
# the Swin-Unet repository before changing it.
IMG_SIZE = 224

# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# =============================================================

# Grayscale value stored in the mask PNGs -> class index used for training.
PIXEL_TO_CLASS = {
    0: 0,  # background
    85: 1,  # sample
    170: 2,  # joint
    255: 3,  # strata
}


def convert_mask(mask):
    """Translate raw 8-bit mask values into class indices (0-3).

    Every pixel whose value is a key of ``PIXEL_TO_CLASS`` becomes the
    corresponding class index; any other value falls back to 0 (background).

    Args:
        mask: NumPy array of raw grayscale mask values.

    Returns:
        A ``uint8`` array with the same shape as ``mask`` holding class indices.
    """
    raw_values = list(PIXEL_TO_CLASS)
    conditions = [mask == value for value in raw_values]
    labels = [PIXEL_TO_CLASS[value] for value in raw_values]
    return np.select(conditions, labels, default=0).astype(np.uint8)

class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Every ``.png`` file in ``image_dir`` is an input image; its mask is the
    file with the same name in ``mask_dir``. Masks are converted to class
    indices with ``convert_mask``.

    Args:
        image_dir: Directory containing the input images (read as color).
        mask_dir: Directory containing grayscale masks with matching names.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``"image"`` and ``"mask"`` keys (albumentations style).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir() order is arbitrary; sorting makes indices (e.g. the
        # dataset[0] sample used for visualization) reproducible.
        self.filenames = sorted(
            f for f in os.listdir(image_dir) if f.endswith(".png")
        )

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _read(path, flags):
        """Read a file with OpenCV, raising instead of silently returning None."""
        data = cv2.imread(path, flags)
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def _load_mask(self, filename):
        """Load the mask paired with *filename* as an array of class indices."""
        mask_path = os.path.join(self.mask_dir, filename)
        return convert_mask(self._read(mask_path, cv2.IMREAD_GRAYSCALE))

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample *idx*; the mask is a long tensor."""
        filename = self.filenames[idx]
        img_path = os.path.join(self.image_dir, filename)

        image = self._read(img_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = self._load_mask(filename)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a transform the mask is still a NumPy array, which has no
        # .long(); as_tensor handles both NumPy arrays and tensors.
        return image, torch.as_tensor(mask).long()

    def get_class_counts(self):
        """Count pixels of each class over all masks (at stored resolution).

        Masks are read directly from disk, without applying ``transform``.

        Returns:
            A float array of length ``NUM_CLASSES`` with per-class pixel counts.
        """
        class_counts = np.zeros(NUM_CLASSES)
        for filename in self.filenames:
            mask = self._load_mask(filename)
            # bincount over class indices; minlength/slice fix the length.
            counts = np.bincount(mask.ravel(), minlength=NUM_CLASSES)
            class_counts += counts[:NUM_CLASSES]
        return class_counts

def get_transforms():
    """Build the preprocessing pipeline used by the dataset.

    Resizes image and mask to IMG_SIZE x IMG_SIZE, normalizes the image with
    the standard ImageNet channel mean/std, and converts the arrays to
    tensors with ToTensorV2.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss, averaged over classes.

    ``pred`` must hold per-class probabilities (e.g. a softmax output) and
    ``target`` the matching one-hot encoding. Both have the same shape, with
    the class axis at dim 1 (``(B, C, H, W)`` as passed by MixedLoss).

    The Dice coefficient is computed separately for each class over the batch
    and all spatial positions, then averaged, so rare classes weigh as much as
    the background. Collapsing all classes into a single Dice term would reduce
    to the mean probability of the true class, which the most frequent class
    dominates.

    Args:
        num_classes: Expected size of dim 1.
        ignore_index: Optional class index whose channel is left out of the
            class average.
    """

    def __init__(self, num_classes, ignore_index=None):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Return ``1 - mean_c(Dice_c)`` as a scalar tensor."""
        smooth = 1e-5
        if pred.shape != target.shape:
            raise ValueError(
                f"pred and target shapes differ: {tuple(pred.shape)} vs {tuple(target.shape)}"
            )
        if pred.dim() < 2 or pred.shape[1] != self.num_classes:
            raise ValueError(
                f"expected {self.num_classes} classes on dim 1, got shape {tuple(pred.shape)}"
            )

        # Reduce over every axis except the class axis (dim 1).
        reduce_dims = [0, *range(2, pred.dim())]
        intersection = (pred * target).sum(dim=reduce_dims)
        union = pred.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
        # smooth keeps classes absent from both pred and target at Dice == 1.
        dice = (2.0 * intersection + smooth) / (union + smooth)

        if self.ignore_index is not None:
            keep = torch.ones_like(dice, dtype=torch.bool)
            keep[self.ignore_index] = False
            dice = dice[keep]
        return 1.0 - dice.mean()

class MixedLoss(nn.Module):
    """Weighted blend of cross-entropy and soft Dice loss.

    Args:
        num_classes: Number of output classes; used both for the one-hot
            encoding of the target and for the Dice term.
        weights: Optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss``; place it on the same device as the logits.
        dice_weight: Weight of the Dice term; cross-entropy gets
            ``1 - dice_weight``.
    """

    def __init__(self, num_classes, weights=None, dice_weight=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)
        self.dice_loss = DiceLoss(num_classes)
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        """Compute the mixed loss.

        Args:
            pred: Raw logits of shape ``(B, C, H, W)``.
            target: Class-index tensor of shape ``(B, H, W)`` (integer dtype).
        """
        # One-hot over this instance's class count (not a module-level
        # constant) so the loss stays consistent with its own DiceLoss.
        target_one_hot = (
            nn.functional.one_hot(target, num_classes=self.num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )
        ce_loss = self.cross_entropy_loss(pred, target)
        dice_loss = self.dice_loss(torch.softmax(pred, dim=1), target_one_hot)

        return ce_loss * (1 - self.dice_weight) + dice_loss * self.dice_weight

def _compute_class_weights(class_counts):
    """Return inverse-frequency class weights (NumPy array) normalized to sum to 1.

    A small epsilon keeps classes with zero pixels from dividing by zero.
    """
    frequencies = class_counts / class_counts.sum()
    inverse = 1.0 / (frequencies + 1e-5)
    return inverse / inverse.sum()


def _train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimization pass over *dataloader* and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for images, masks in dataloader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


def train():
    """Train Swin-Unet on IMAGE_DIR/MASK_DIR, save the weights, show one prediction.

    Uses MixedLoss (class-weighted cross-entropy + Dice) and AdamW. The final
    state dict is written to ``swinunet_model.pth``.

    Raises:
        RuntimeError: if no ``.png`` images are found in IMAGE_DIR.
    """
    transform = get_transforms()
    dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR, transform=transform)
    if len(dataset) == 0:
        # Fail early: an empty dataset would otherwise yield NaN class
        # weights and a ZeroDivisionError when averaging the epoch loss.
        raise RuntimeError(f"No .png images found in {IMAGE_DIR!r}")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Class weights for the cross-entropy term, from pixel frequencies.
    class_weights = _compute_class_weights(dataset.get_class_counts())
    class_weights = torch.from_numpy(class_weights).float().to(DEVICE)
    print(f"Class weights: {class_weights}")

    model = SwinTransformerSys(img_size=IMG_SIZE, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = MixedLoss(num_classes=NUM_CLASSES, weights=class_weights)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(EPOCHS):
        mean_loss = _train_one_epoch(model, dataloader, criterion, optimizer)
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {mean_loss:.4f}")

    torch.save(model.state_dict(), "swinunet_model.pth")
    visualize_prediction(model, dataset)

def visualize_prediction(model, dataset, index=0):
    """Plot the input image, the model prediction and the ground-truth mask.

    Args:
        model: Segmentation network; it is switched to eval mode.
        dataset: Dataset yielding ``(normalized CHW image tensor, class-index
            mask tensor)`` pairs.
        index: Which sample of *dataset* to display (default: the first).
    """
    model.eval()
    image, mask = dataset[index]
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(DEVICE))
    # Take the single batch element explicitly; squeeze() would also drop
    # any other axis that happens to have size 1.
    pred = torch.argmax(logits[0], dim=0).cpu().numpy()

    # Undo the normalization for display (values must match get_transforms).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    unnormalized_img = image.permute(1, 2, 0).cpu().numpy() * std + mean
    unnormalized_img = np.clip(unnormalized_img, 0, 1)

    # Map class indices back to the grayscale values used in the mask files.
    class_to_pixel = {v: k for k, v in PIXEL_TO_CLASS.items()}
    pred_pixels = np.vectorize(class_to_pixel.get)(pred)
    true_pixels = np.vectorize(class_to_pixel.get)(mask.cpu().numpy())

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(unnormalized_img)

    plt.subplot(1, 3, 2)
    plt.title("Prediction")
    plt.imshow(pred_pixels, cmap="jet", vmin=0, vmax=255)

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth")
    plt.imshow(true_pixels, cmap="jet", vmin=0, vmax=255)

    plt.tight_layout()
    plt.show()

def main():
    """Script entry point: run training, which also saves and visualizes the model."""
    train()


if __name__ == "__main__":
    main()





Nessun commento:

Posta un commento

Kernel Panic QrCode

 In tanti anni ho visto qualche kernel panic, ma in questo formato non mi era mai successo. La cosa curiosa è che al riavvio successivo ness...