Reading up on the topic, I found that one of the most recent developments in image segmentation is the Swin-Unet network (born from the fusion of the Swin Transformer and U-Net architectures).
To get the network running, first clone this repository into the project folder: https://github.com/HuCaoFighting/Swin-Unet
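For example, from the project root (the script below adds the cloned folder to sys.path, so no installation step is needed):

git clone https://github.com/HuCaoFighting/Swin-Unet

Besides the repository itself, the script assumes that PyTorch, OpenCV (opencv-python), NumPy, Albumentations, Matplotlib and timm are installed; timm is pulled in internally by the Swin-Unet model code.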
Computation time is markedly lower than with DeepLab.

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import sys
# Make the cloned Swin-Unet repository importable
sys.path.append('./Swin-Unet')
from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
import timm  # not used directly here, but required by the Swin-Unet repository code
# -------- CONFIG --------
IMAGE_DIR = "data/images"
MASK_DIR = "data/masks"
NUM_CLASSES = 4
EPOCHS = 30
BATCH_SIZE = 8
IMG_SIZE = 224 # Corrected for Swin-Unet compatibility
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------
# Map pixel values to class indices
PIXEL_TO_CLASS = {
    0: 0,    # background
    85: 1,   # sample
    170: 2,  # joint
    255: 3   # strata
}
def convert_mask(mask):
    """Map 8-bit pixel values to class indices (0-3)."""
    new_mask = np.zeros_like(mask, dtype=np.uint8)
    for pixel_val, class_idx in PIXEL_TO_CLASS.items():
        new_mask[mask == pixel_val] = class_idx
    return new_mask
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.filenames = [f for f in os.listdir(image_dir) if f.endswith(".png")]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.filenames[idx])
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = convert_mask(mask)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        return image, mask.long()

    def get_class_counts(self):
        """Calculates pixel counts for each class to determine weights."""
        class_counts = np.zeros(NUM_CLASSES)
        for filename in self.filenames:
            mask_path = os.path.join(self.mask_dir, filename)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            mask = convert_mask(mask)
            for i in range(NUM_CLASSES):
                class_counts[i] += np.sum(mask == i)
        return class_counts
def get_transforms():
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
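# The pipeline above only resizes and normalizes. As a hypothetical extension
# (not part of the original setup), a training-time variant with light
# augmentation could look like this; Albumentations applies the same spatial
# transforms to image and mask, so labels stay aligned:
def get_train_transforms():
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])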
class DiceLoss(nn.Module):
    """Soft Dice loss; expects softmax probabilities and one-hot encoded targets."""
    def __init__(self, num_classes, ignore_index=None):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        smooth = 1e-5
        iflat = pred.contiguous().view(-1)
        tflat = target.contiguous().view(-1)
        intersection = (iflat * tflat).sum()
        union = iflat.sum() + tflat.sum()
        return 1.0 - ((2.0 * intersection + smooth) / (union + smooth))
class MixedLoss(nn.Module):
    def __init__(self, num_classes, weights, dice_weight=0.5):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)
        self.dice_loss = DiceLoss(num_classes)
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        # Convert target to one-hot (N, C, H, W) for DiceLoss
        target_one_hot = nn.functional.one_hot(target, num_classes=NUM_CLASSES).permute(0, 3, 1, 2).float()
        ce_loss = self.cross_entropy_loss(pred, target)
        dice_loss = self.dice_loss(torch.softmax(pred, dim=1), target_one_hot)
        return ce_loss * (1 - self.dice_weight) + dice_loss * self.dice_weight
def train():
    transform = get_transforms()
    dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Calculate class weights for the Cross-Entropy loss (inverse frequency)
    class_counts = dataset.get_class_counts()
    total_pixels = class_counts.sum()
    class_weights = 1.0 / (class_counts / total_pixels + 1e-5)  # small epsilon prevents division by zero
    class_weights = class_weights / class_weights.sum()  # normalize to sum to 1
    class_weights = torch.from_numpy(class_weights).float().to(DEVICE)
    print(f"Class weights: {class_weights}")

    model = SwinTransformerSys(img_size=IMG_SIZE, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    # Use the MixedLoss
    criterion = MixedLoss(num_classes=NUM_CLASSES, weights=class_weights)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        for images, masks in dataloader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss / len(dataloader):.4f}")

    torch.save(model.state_dict(), "swinunet_model.pth")
    visualize_prediction(model, dataset)
def visualize_prediction(model, dataset):
    model.eval()
    image, mask = dataset[0]
    with torch.no_grad():
        input_tensor = image.unsqueeze(0).to(DEVICE)
        output = model(input_tensor)
        pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

    # Reverse normalization for visualization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    unnormalized_img = image.permute(1, 2, 0).cpu().numpy() * std + mean
    unnormalized_img = np.clip(unnormalized_img, 0, 1)

    # Map class indices back to the original 8-bit pixel values
    class_to_pixel = {v: k for k, v in PIXEL_TO_CLASS.items()}
    pred_mask_rgb = np.vectorize(class_to_pixel.get)(pred)
    true_mask_rgb = np.vectorize(class_to_pixel.get)(mask.numpy())

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(unnormalized_img)
    plt.subplot(1, 3, 2)
    plt.title("Prediction")
    plt.imshow(pred_mask_rgb, cmap="jet", vmin=0, vmax=255)
    plt.subplot(1, 3, 3)
    plt.title("Ground Truth")
    plt.imshow(true_mask_rgb, cmap="jet", vmin=0, vmax=255)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    train()
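Once training is complete, the saved weights can be reloaded for inference on a new image. Below is a minimal sketch assuming the definitions from the script above; the file name test.png is only a placeholder:

model = SwinTransformerSys(img_size=IMG_SIZE, num_classes=NUM_CLASSES)
model.load_state_dict(torch.load("swinunet_model.pth", map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

image = cv2.cvtColor(cv2.imread("test.png"), cv2.COLOR_BGR2RGB)  # placeholder input
tensor = get_transforms()(image=image)["image"].unsqueeze(0).to(DEVICE)
with torch.no_grad():
    pred = torch.argmax(model(tensor), dim=1).squeeze(0).cpu().numpy()
# pred now holds class indices (0-3) on a 224x224 grid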
