mercoledì 6 agosto 2025

DeepLab V3+ su carote di sondaggio

Usando lo stesso dataset e le stesse maschere di training del post precedente, ho provato la rete DeepLab V3+.



Computazionalmente DeepLab è risultata più impegnativa di U-Net ma i risultati, come si vede dall'immagine di confronto soprastante, sono decisamente migliori, in particolare per quanto riguarda i falsi positivi.

pip install torch torchvision albumentations opencv-python matplotlib


import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

# -------- CONFIG --------
# Directory scanned for the ".png" input images.
IMAGE_DIR = "data/images"
# Directory holding the masks; each mask is looked up under the same file
# name as its image.
MASK_DIR = "data/masks"
# Number of output classes requested from the model; matches the four
# entries of PIXEL_TO_CLASS below.
NUM_CLASSES = 4
# Number of full passes over the dataset in the training loop.
EPOCHS = 20
# Samples per batch handed to the DataLoader.
BATCH_SIZE = 4
# Images and masks are resized to IMG_SIZE x IMG_SIZE before training.
IMG_SIZE = 512
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------

# Grayscale value stored in the mask images -> class index used for training.
PIXEL_TO_CLASS = {
    0: 0,    # background
    85: 1,   # sample
    170: 2,  # joint
    255: 3,  # strata
}


def convert_mask(mask):
    """Map 8-bit mask pixel values to class indices (0-3).

    Args:
        mask: Array of grayscale values, e.g. the result of
            ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``.

    Returns:
        A ``uint8`` array with the same shape as ``mask`` that holds class
        indices. Values that are not keys of ``PIXEL_TO_CLASS`` are left
        at 0, i.e. they are treated as background.

    Raises:
        ValueError: If ``mask`` is ``None``, which is what ``cv2.imread``
            returns for a file it cannot read.
    """
    if mask is None:
        # Without this guard np.zeros_like(None) quietly yields a 0-d array
        # and the failure only surfaces later, far from its cause.
        raise ValueError("mask is None (the mask file could not be read)")

    new_mask = np.zeros_like(mask, dtype=np.uint8)
    for pixel_val, class_idx in PIXEL_TO_CLASS.items():
        new_mask[mask == pixel_val] = class_idx
    return new_mask

class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two directories of PNG files.

    Every ``.png`` file in ``image_dir`` is one sample; its mask is the file
    with the same name inside ``mask_dir``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        """Index the PNG files found in ``image_dir``.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the grayscale masks, named
                exactly like the corresponding images.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping
                with ``"image"`` and ``"mask"`` entries.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and therefore dataset[0]) reproducible.
        self.filenames = sorted(
            f for f in os.listdir(image_dir) if f.endswith(".png")
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is read as RGB and the mask as class indices. When a
        transform is set, both go through it. The mask is always returned
        as an ``int64`` tensor; the image is returned as produced by the
        transform (or as the raw RGB array when there is none).

        Raises:
            OSError: If the image or its mask is missing or unreadable.
        """
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.filenames[idx])

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError(f"Image is missing or unreadable: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Mask is missing or unreadable: {mask_path}")
        mask = convert_mask(mask)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # The mask is a tensor after a tensor-producing transform but still
        # an ndarray when no transform is set; as_tensor handles both.
        return image, torch.as_tensor(mask).long()

def get_transforms():
    """Build the preprocessing pipeline applied to every image/mask pair.

    Returns:
        An ``albumentations.Compose`` that resizes to
        ``IMG_SIZE`` x ``IMG_SIZE``, normalizes the image channels and
        converts the result to PyTorch tensors.
    """
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        # Per-channel mean/std commonly used for ImageNet-pretrained models.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

def train():
    """Train the segmentation model, save its weights and plot one prediction.

    Reads image/mask pairs from ``IMAGE_DIR`` / ``MASK_DIR``, trains for
    ``EPOCHS`` epochs with Adam and cross-entropy loss, writes the weights
    to ``deeplabv3_model.pth`` in the working directory and finally shows a
    prediction for the first sample.

    NOTE(review): ``deeplabv3_resnet50`` is torchvision's DeepLabV3 builder;
    confirm whether a DeepLabV3+ (decoder) variant was intended.

    Raises:
        RuntimeError: If no ``.png`` image is found in ``IMAGE_DIR``.
    """
    # Load dataset
    transform = get_transforms()
    dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR, transform=transform)
    if len(dataset) == 0:
        # Fail early with a readable message instead of an obscure sampler
        # or indexing error further down.
        raise RuntimeError(f"No .png images found in {IMAGE_DIR!r}")

    # A trailing batch with a single sample makes the BatchNorm layer that
    # follows the global pooling in the model head raise in training mode
    # ("Expected more than 1 value per channel"), so drop that one batch.
    drop_last = len(dataset) > 1 and len(dataset) % BATCH_SIZE == 1
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=drop_last,
    )

    # Load model
    model = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for images, masks in dataloader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)

            # The model returns a mapping; "out" holds the per-class logits.
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Reported value is the sum of the batch losses, not their mean.
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "deeplabv3_model.pth")

    # Visualize prediction on 1 sample
    visualize_prediction(model, dataset)

def visualize_prediction(model, dataset):
    """Plot input image, predicted mask and ground-truth mask for sample 0.

    Args:
        model: Segmentation model whose forward pass returns a mapping with
            an ``"out"`` entry holding per-class logits. It is switched to
            eval mode and left there.
        dataset: Indexable returning ``(image, mask)`` tensors. The image is
            assumed to be channel-first and normalized with the mean/std
            used in ``get_transforms``.

    Raises:
        IndexError: If a class index in the prediction or the mask has no
            entry in ``PIXEL_TO_CLASS``.
    """
    model.eval()
    image, mask = dataset[0]
    with torch.no_grad():
        input_tensor = image.unsqueeze(0).to(DEVICE)
        output = model(input_tensor)["out"]
    # Arg-max over the class dimension, then drop the batch dimension
    # explicitly (a bare squeeze() would also drop a size-1 height/width).
    pred = output.argmax(dim=1)[0].cpu().numpy()

    # Convert back to pixel values with a lookup table indexed by class;
    # this is one vectorized gather instead of a Python call per pixel.
    class_to_pixel = np.zeros(max(PIXEL_TO_CLASS.values()) + 1, dtype=np.uint8)
    for pixel_val, class_idx in PIXEL_TO_CLASS.items():
        class_to_pixel[class_idx] = pixel_val
    pred_mask_gray = class_to_pixel[pred]
    true_mask_gray = class_to_pixel[mask.numpy()]

    # Undo the normalization so imshow receives values in [0, 1]; showing
    # the normalized tensor directly gets clipped and looks wrong.
    # These constants must match A.Normalize in get_transforms.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    display_image = (image.cpu() * std + mean).clamp(0.0, 1.0)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(display_image.permute(1, 2, 0).numpy())

    plt.subplot(1, 3, 2)
    plt.title("Prediction")
    plt.imshow(pred_mask_gray, cmap="jet", vmin=0, vmax=255)

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth")
    plt.imshow(true_mask_gray, cmap="jet", vmin=0, vmax=255)

    plt.tight_layout()
    plt.show()

# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    train()



Nessun commento:

Posta un commento

Chiavetta ALIA

Sono maledettamente distratto e stavo cercando di vedere se riesco a replicare la chiavetta dei cassonetti di Firenze senza dover per forza ...