Usando lo stesso dataset e le stesse maschere di training del precedente post, ho provato la rete DeepLab V3+.
pip install torch torchvision albumentations opencv-python matplotlib
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
# -------- CONFIG --------
IMAGE_DIR = "data/images"  # input RGB images (.png)
MASK_DIR = "data/masks"    # 8-bit grayscale masks, same filenames as images
NUM_CLASSES = 4            # background, sample, joint, strata
EPOCHS = 20
BATCH_SIZE = 4
IMG_SIZE = 512             # images and masks are resized to IMG_SIZE x IMG_SIZE
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------
# Map raw 8-bit mask pixel values to contiguous class indices.
PIXEL_TO_CLASS = {
    0: 0,    # background
    85: 1,   # sample
    170: 2,  # joint
    255: 3   # strata
}
def convert_mask(mask, pixel_to_class=None):
    """Map raw 8-bit mask pixel values to contiguous class indices.

    Args:
        mask: 2-D uint8 array of raw mask pixel values.
        pixel_to_class: optional {raw_value: class_index} mapping; defaults
            to the module-level PIXEL_TO_CLASS.

    Returns:
        uint8 array, same shape as `mask`. Pixel values not present in the
        mapping fall back to 0 (background), as in the original behavior.
    """
    if pixel_to_class is None:
        pixel_to_class = PIXEL_TO_CLASS
    new_mask = np.zeros_like(mask, dtype=np.uint8)
    for pixel_val, class_idx in pixel_to_class.items():
        new_mask[mask == pixel_val] = class_idx
    return new_mask
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Expects every ``.png`` in ``image_dir`` to have a same-named 8-bit
    grayscale mask in ``mask_dir``. Raw mask pixel values are remapped to
    class indices via ``convert_mask``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort for a deterministic ordering: os.listdir order is arbitrary
        # and varies across filesystems.
        self.filenames = sorted(
            f for f in os.listdir(image_dir) if f.endswith(".png")
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.filenames[idx])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising on a bad/unreadable
            # path; fail loudly with the offending filename.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = convert_mask(mask)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        # NOTE(review): mask.long() assumes the transform produced a torch
        # tensor (ToTensorV2); with transform=None mask is still a numpy
        # array and this call would fail — same as the original code.
        return image, mask.long()
def get_transforms():
    """Build the deterministic preprocessing pipeline.

    Resizes to IMG_SIZE x IMG_SIZE, applies ImageNet mean/std normalization,
    and converts image and mask to torch tensors.
    """
    resize = A.Resize(IMG_SIZE, IMG_SIZE)
    normalize = A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    return A.Compose([resize, normalize, ToTensorV2()])
def train():
    """Train DeepLab V3+ (ResNet-50 backbone) from scratch on the dataset.

    Saves the trained weights to ``deeplabv3_model.pth`` and visualizes a
    prediction on one sample when done.
    """
    # Load dataset
    transform = get_transforms()
    dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # weights=None: no pretrained backbone, head sized to NUM_CLASSES.
    model = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for images, masks in dataloader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            # torchvision segmentation models return a dict; "out" is the
            # main prediction head.
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the mean batch loss: the raw sum grows with the number of
        # batches and is not comparable across dataset sizes.
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "deeplabv3_model.pth")

    # Visualize prediction on 1 sample
    visualize_prediction(model, dataset)
def visualize_prediction(model, dataset):
    """Plot input image, predicted mask, and ground-truth mask for dataset[0].

    Args:
        model: trained torchvision segmentation model (dict output, key "out").
        dataset: SegmentationDataset yielding (normalized image tensor, mask).
    """
    model.eval()
    image, mask = dataset[0]
    with torch.no_grad():
        input_tensor = image.unsqueeze(0).to(DEVICE)
        output = model(input_tensor)["out"]
        pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

    # Convert class indices back to the original 8-bit pixel values so the
    # prediction and ground truth share the same 0-255 color scale.
    class_to_pixel = {v: k for k, v in PIXEL_TO_CLASS.items()}
    pred_mask_rgb = np.vectorize(class_to_pixel.get)(pred)
    true_mask_rgb = np.vectorize(class_to_pixel.get)(mask.numpy())

    # Undo the ImageNet normalization applied in get_transforms(); without
    # this, the tensor has values outside [0, 1] and matplotlib clips them,
    # producing distorted colors in the "Input Image" panel.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    display_img = image.permute(1, 2, 0).cpu().numpy() * std + mean
    display_img = np.clip(display_img, 0.0, 1.0)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(display_img)
    plt.subplot(1, 3, 2)
    plt.title("Prediction")
    plt.imshow(pred_mask_rgb, cmap="jet", vmin=0, vmax=255)
    plt.subplot(1, 3, 3)
    plt.title("Ground Truth")
    plt.imshow(true_mask_rgb, cmap="jet", vmin=0, vmax=255)
    plt.tight_layout()
    plt.show()
# Script entry point: run the full training pipeline end-to-end.
if __name__ == "__main__":
    train()

Nessun commento:
Posta un commento