Sulla base di quanto letto sul dataset SpectralWaste e sul codice allegato, ho voluto provare SegFormer su due dataset pubblici e classici per il test di questo tipo di algoritmi: Indian Pines e Pavia.
Indian Pines ha una risoluzione spettrale di 220 bande, una risoluzione a terra di 30 m e dimensioni di 145x145 pixel. È quindi un dataset molto piccolo, che necessita di data augmentation; inoltre è anche sbilanciato, in quanto la classe Background è molto più numerosa delle altre classi.
Pavia ha una risoluzione spettrale di 103 bande, una risoluzione a terra di 1.3 m per pixel e dimensioni di 610x340 pixel.
I dataset possono essere scaricati da qui
Pavia



import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation
from sklearn.metrics import cohen_kappa_score, accuracy_score, classification_report, confusion_matrix
import seaborn as sns
# --- CONFIGURATION FOR INDIAN PINES ---
# Experiment settings for the (corrected) Indian Pines scene.
CONFIG = {
    "dataset": "IndianPines",
    "data_path": "Indian_pines_corrected.mat",  # hyperspectral cube
    "gt_path": "Indian_pines_gt.mat",           # ground-truth label map
    "in_channels": 200,   # standard band count for corrected Indian Pines
    "num_classes": 17,    # 16 classes + 1 background (0)
    "window_size": 32,    # smaller window for the 145x145 image
    "stride": 4,          # small stride to increase the patch count
    "train_ratio": 0.2,   # 20% training, 80% testing
    "batch_size": 16,
    "epochs": 80,         # increased epochs for convergence
    "lr": 1e-4,
}
# 1. DATASET WITH AUTOMATIC KEY DETECTION
class HSIDataset(Dataset):
    """Hyperspectral patch dataset with a seeded train/test pixel split.

    Loads a .mat cube and its ground truth, min-max normalizes the cube,
    assigns each labeled pixel (gt > 0) to the train or test split using
    a fixed-seed shuffle, then extracts sliding-window patches that
    contain at least one labeled pixel of the requested split.
    """

    def __init__(self, cfg, is_train=True, augment=True):
        cube_mat = loadmat(cfg["data_path"])
        gt_mat = loadmat(cfg["gt_path"])
        # Auto-detect payload keys; .mat metadata keys start with '__'.
        cube_key = next(k for k in cube_mat.keys() if not k.startswith('__'))
        label_key = next(k for k in gt_mat.keys() if not k.startswith('__'))
        cube = cube_mat[cube_key].astype(np.float32)
        gt = gt_mat[label_key].astype(np.int64)
        # Global min-max normalization to [0, 1].
        lo, hi = np.min(cube), np.max(cube)
        cube = (cube - lo) / (hi - lo)
        self.data = np.transpose(cube, (2, 0, 1))  # HWC -> CHW
        # Split only the labeled pixels; the fixed seed guarantees the
        # train and test instances agree on the partition.
        rows, cols = np.where(gt > 0)
        order = np.arange(rows.size)
        np.random.seed(42)
        np.random.shuffle(order)
        n_train = int(rows.size * cfg["train_ratio"])
        keep = order[:n_train] if is_train else order[n_train:]
        split_gt = np.zeros_like(gt)
        split_gt[rows[keep], cols[keep]] = gt[rows[keep], cols[keep]]
        self.gt = split_gt
        self.augment = augment and is_train
        # Sliding-window patch extraction; keep only patches holding at
        # least one labeled pixel of this split.
        win, step = cfg["window_size"], cfg["stride"]
        patches, labels = [], []
        _, h, w = self.data.shape
        for top in range(0, h - win + 1, step):
            for left in range(0, w - win + 1, step):
                lbl = self.gt[top:top + win, left:left + win]
                if (lbl > 0).any():
                    patches.append(self.data[:, top:top + win, left:left + win])
                    labels.append(lbl)
        self.patches = np.array(patches)
        self.labels = np.array(labels)

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        patch, label = self.patches[idx], self.labels[idx]
        if self.augment:
            # Random horizontal / vertical flips applied jointly to the
            # patch (C, H, W) and its label map (H, W).
            if random.random() > 0.5:
                patch = np.flip(patch, axis=2).copy()
                label = np.flip(label, axis=1).copy()
            if random.random() > 0.5:
                patch = np.flip(patch, axis=1).copy()
                label = np.flip(label, axis=0).copy()
        return torch.from_numpy(patch), torch.from_numpy(label)
# 2. MODEL DEFINITION
class SegFormerHSI(nn.Module):
    """SegFormer (MiT-b0) adapted to hyperspectral input.

    A 1x1 convolution projects the spectral bands down to the 3 channels
    the pretrained backbone expects; output logits are upsampled back to
    the input spatial size.
    """

    def __init__(self, in_ch, num_cl):
        super().__init__()
        # Reducer compresses the spectral dimension to 3 channels.
        self.reducer = nn.Conv2d(in_ch, 3, kernel_size=1)
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=num_cl,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        reduced = self.reducer(x)
        outputs = self.model(reduced)
        # Upsample the low-resolution logits to the original patch size.
        return nn.functional.interpolate(
            outputs.logits,
            size=reduced.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
# 3. VISUALIZATION FUNCTIONS
def plot_results(train_gt, full_gt, pred_map):
    """Show training pixels, the full prediction map and an error map."""
    # RGB error map: white = unlabeled, red = wrong, green = correct.
    error_display = np.full(full_gt.shape + (3,), 1.0)
    labeled = full_gt > 0
    error_display[labeled] = [1, 0, 0]  # Red for errors
    error_display[labeled & (full_gt == pred_map)] = [0, 1, 0]  # Green for correct

    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    # Training mask: background mapped to NaN so it renders blank.
    train_view = train_gt.astype(float)
    train_view[train_view == 0] = np.nan
    ax[0].imshow(train_view, cmap='nipy_spectral')
    ax[0].set_title("Training Pixels (20%)")
    # Full-scene prediction.
    ax[1].imshow(pred_map, cmap='nipy_spectral')
    ax[1].set_title("Full Prediction Map")
    # Error map.
    ax[2].imshow(error_display)
    ax[2].set_title("Error Map (Green=Correct)")
    for panel in ax:
        panel.axis('off')
    plt.show()
# 4. MAIN EXECUTION
def main():
    """Train SegFormer on Indian Pines patches and evaluate on the test split."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_ds = HSIDataset(CONFIG, is_train=True)
    test_ds = HSIDataset(CONFIG, is_train=False, augment=False)
    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True)

    model = SegFormerHSI(CONFIG["in_channels"], CONFIG["num_classes"]).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    # Background (class 0) never contributes to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    print(f"Dataset: Indian Pines | Training Patches: {len(train_ds)}")

    # --- Training ---
    model.train()
    for epoch in range(CONFIG["epochs"]):
        total_loss = 0
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), masks)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 10 == 0:
            print(f"Epoch {epoch} | Loss: {total_loss/len(train_loader):.4f}")

    # --- Evaluation: predict the whole scene in one forward pass ---
    model.eval()
    with torch.no_grad():
        full_img = torch.from_numpy(test_ds.data).unsqueeze(0).to(device)
        pred_map = torch.argmax(model(full_img), dim=1).squeeze(0).cpu().numpy()

    # Metrics restricted to pixels assigned to the TEST split.
    test_mask = test_ds.gt > 0
    y_true = test_ds.gt[test_mask]
    y_pred = pred_map[test_mask]
    print(f"\n--- INDIAN PINES RESULTS ---")
    print(f"Overall Accuracy: {accuracy_score(y_true, y_pred):.4f}")
    print(f"Kappa: {cohen_kappa_score(y_true, y_pred):.4f}")

    # Reload the complete ground truth for visualization.
    raw_gt = loadmat(CONFIG["gt_path"])
    gt_key = next(k for k in raw_gt.keys() if not k.startswith('__'))
    full_gt = raw_gt[gt_key].astype(np.int64)
    plot_results(train_ds.gt, full_gt, pred_map)


if __name__ == "__main__":
    main()
Indian Pines


import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation
from sklearn.metrics import cohen_kappa_score, accuracy_score, classification_report
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_results_summary(train_gt, full_gt, pred_map, dataset_name):
    """Plot the training mask, full prediction map and error map side by side.

    Args:
        train_gt: 2-D label map of the training pixels (0 = not in train split).
        full_gt: 2-D complete ground-truth label map (0 = unlabeled).
        pred_map: 2-D predicted class per pixel.
        dataset_name: name shown in the panel titles.
    """
    # RGB error map: white = unlabeled background, red = error, green = correct.
    # Only pixels with ground truth (full_gt > 0) are evaluated.
    # (The original also built an unused 0/1 `error_map` array; removed.)
    error_display = np.full(full_gt.shape + (3,), 1.0)  # White background
    labeled = full_gt > 0
    error_display[labeled] = [1, 0, 0]  # Default Red (Error)
    error_display[labeled & (full_gt == pred_map)] = [0, 1, 0]  # Green (Correct)

    plt.figure(figsize=(18, 6))
    # 1. Training Mask (background hidden via NaN).
    plt.subplot(1, 3, 1)
    train_view = train_gt.astype(float)
    train_view[train_view == 0] = np.nan
    plt.imshow(train_view, cmap='nipy_spectral')
    # dataset_name was previously accepted but unused; include it in the
    # titles, consistent with plot_train_vs_prediction.
    plt.title(f"{dataset_name}: Training Mask (Used pixels)")
    plt.axis('off')
    # 2. Prediction over the whole scene.
    plt.subplot(1, 3, 2)
    plt.imshow(pred_map, cmap='nipy_spectral')
    plt.title(f"{dataset_name}: Model Prediction (Full Map)")
    plt.axis('off')
    # 3. Error Map (Green = Correct, Red = Wrong).
    plt.subplot(1, 3, 3)
    plt.imshow(error_display)
    plt.title(f"{dataset_name}: Error Map (Green=Correct, Red=Error)")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
def plot_train_vs_prediction(train_gt, pred_map, dataset_name):
    """Display the training mask next to the full-scene prediction."""
    # Background (0) mapped to NaN so it renders as white/empty.
    train_display = train_gt.astype(float)
    train_display[train_display == 0] = np.nan
    pred_display = pred_map.astype(float)
    # Optional: mask prediction with where labels actually exist in reality
    # pred_display[dataset.gt == 0] = np.nan

    plt.figure(figsize=(14, 7))
    # Left panel: pixels used for training.
    plt.subplot(1, 2, 1)
    plt.imshow(train_display, cmap='nipy_spectral')
    plt.title(f"{dataset_name}: Training Pixels (20%)")
    plt.axis('off')
    # Right panel: model prediction over the whole scene.
    plt.subplot(1, 2, 2)
    plt.imshow(pred_display, cmap='nipy_spectral')
    plt.title(f"{dataset_name}: SegFormer Full Prediction")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
def plot_confusion_matrix(y_true, y_pred, dataset_name):
    """Plot a row-normalized confusion matrix as a heatmap.

    Each row is divided by its total so cells show the fraction of each
    true class predicted as each label. Rows with zero support (a label
    appearing only in y_pred) are left at zero instead of producing NaNs,
    which the original row-wise division would have emitted.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Normalize by row (true labels); clamp zero totals to 1 to avoid 0/0.
    row_totals = np.maximum(cm.sum(axis=1)[:, np.newaxis], 1)
    cm_norm = cm.astype('float') / row_totals
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues")
    plt.title(f"Normalized Confusion Matrix: {dataset_name}")
    plt.ylabel("True Class")
    plt.xlabel("Predicted Class")
    plt.show()
# --- CONFIGURATION ---
# Experiment settings for the Pavia University scene.
CONFIG = {
    "dataset": "PaviaU",
    "data_path": "PaviaU.mat",     # hyperspectral cube
    "gt_path": "PaviaU_gt.mat",    # ground-truth label map
    "in_channels": 103,            # spectral bands
    "num_classes": 10,             # 9 classes + background (0)
    "window_size": 64,             # patch side in pixels
    "stride": 8,                   # sliding-window step
    "train_ratio": 0.2,            # use 20% of labeled pixels for training
    "batch_size": 16,
    "epochs": 60,
    "lr": 1e-4,
}
# 1. DATASET WITH SPATIAL SPLIT
class HSIDataset(Dataset):
    """Patch dataset for PaviaU / Indian Pines with a seeded pixel split.

    Only labeled pixels (gt > 0) are partitioned into train/test; the
    fixed RNG seed makes both dataset instances agree on the partition.
    """

    def __init__(self, cfg, is_train=True, augment=True):
        raw_data = loadmat(cfg["data_path"])
        raw_gt = loadmat(cfg["gt_path"])
        # Dataset-specific .mat payload keys.
        data_key = "paviaU" if cfg["dataset"] == "PaviaU" else "indian_pines_corrected"
        gt_key = "paviaU_gt" if cfg["dataset"] == "PaviaU" else "indian_pines_gt"
        data = raw_data[data_key].astype(np.float32)
        gt = raw_gt[gt_key].astype(np.int64)
        # Min-max normalize the whole cube to [0, 1].
        data = (data - np.min(data)) / (np.max(data) - np.min(data))
        self.data = np.transpose(data, (2, 0, 1))  # HWC -> CHW
        # Shuffle the labeled pixel indices with a fixed seed and split.
        rows, cols = np.where(gt > 0)
        perm = np.arange(rows.size)
        np.random.seed(42)
        np.random.shuffle(perm)
        n_train = int(rows.size * cfg["train_ratio"])
        chosen = perm[:n_train] if is_train else perm[n_train:]
        split_gt = np.zeros_like(gt)
        split_gt[rows[chosen], cols[chosen]] = gt[rows[chosen], cols[chosen]]
        self.gt = split_gt
        self.augment = augment and is_train
        # Generate sliding-window patches; keep a patch only if it holds
        # labeled pixels of this split.
        win, step = cfg["window_size"], cfg["stride"]
        self.patches, self.labels = [], []
        _, h, w = self.data.shape
        for i in range(0, h - win + 1, step):
            for j in range(0, w - win + 1, step):
                patch_gt = self.gt[i:i + win, j:j + win]
                if (patch_gt > 0).any():
                    self.patches.append(self.data[:, i:i + win, j:j + win])
                    self.labels.append(patch_gt)
        self.patches = np.array(self.patches)
        self.labels = np.array(self.labels)

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        patch, label = self.patches[idx], self.labels[idx]
        if self.augment:
            # Random horizontal then vertical flip, applied to patch
            # (C, H, W) and label (H, W) together.
            if random.random() > 0.5:
                patch = np.flip(patch, axis=2).copy()
                label = np.flip(label, axis=1).copy()
            if random.random() > 0.5:
                patch = np.flip(patch, axis=1).copy()
                label = np.flip(label, axis=0).copy()
        return torch.from_numpy(patch), torch.from_numpy(label)
# 2. MODEL & METRICS (Same as before)
class SegFormerHSI(nn.Module):
    """SegFormer (MiT-b0) with a 1x1-conv spectral reducer in front."""

    def __init__(self, in_ch, num_cl):
        super().__init__()
        # Project in_ch spectral bands to the 3 channels the backbone expects.
        self.reducer = nn.Conv2d(in_ch, 3, kernel_size=1)
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0", num_labels=num_cl, ignore_mismatched_sizes=True
        )

    def forward(self, x):
        rgb_like = self.reducer(x)
        result = self.model(rgb_like)
        # Restore the logits to the input spatial resolution.
        return nn.functional.interpolate(
            result.logits, size=rgb_like.shape[-2:], mode="bilinear", align_corners=False
        )
# 3. MAIN TRAINING & VALIDATION
def main():
    """Train SegFormer on the train-split patches, evaluate on test pixels."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Datasets share the seeded split; augmentation is disabled for test.
    train_ds = HSIDataset(CONFIG, is_train=True)
    test_ds = HSIDataset(CONFIG, is_train=False, augment=False)
    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True)

    model = SegFormerHSI(CONFIG["in_channels"], CONFIG["num_classes"]).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # skip background pixels
    print(f"Train Patches: {len(train_ds)} | Test Patches: {len(test_ds)}")

    # --- Training loop ---
    for epoch in range(CONFIG["epochs"]):
        model.train()
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), masks)
            loss.backward()
            optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch} complete.")

    # --- Evaluation on the TEST split only ---
    model.eval()
    with torch.no_grad():
        full_img = torch.from_numpy(test_ds.data).unsqueeze(0).to(device)
        pred_map = torch.argmax(model(full_img), dim=1).squeeze(0).cpu().numpy()

    # Restrict the metrics to pixels assigned to the test split.
    mask = test_ds.gt > 0
    oa = accuracy_score(test_ds.gt[mask], pred_map[mask])
    kappa = cohen_kappa_score(test_ds.gt[mask], pred_map[mask])
    print(f"\n--- TEST SET RESULTS ---")
    print(f"Overall Accuracy: {oa:.4f}")
    print(f"Kappa: {kappa:.4f}")
    print("\nClass-wise Report:")
    print(classification_report(test_ds.gt[mask], pred_map[mask]))
    plot_confusion_matrix(test_ds.gt[mask], pred_map[mask], CONFIG["dataset"])
    plot_train_vs_prediction(train_ds.gt, pred_map, CONFIG["dataset"])

    # Reload the complete ground truth for the summary figure.
    mat_gt = loadmat(CONFIG["gt_path"])
    gt_key = "paviaU_gt" if CONFIG["dataset"] == "PaviaU" else "indian_pines_gt"
    full_gt = mat_gt[gt_key].astype(np.int64)
    print("Generating Results Plots...")
    plot_results_summary(
        train_gt=train_ds.gt,
        full_gt=full_gt,
        pred_map=pred_map,
        dataset_name=CONFIG["dataset"],
    )


if __name__ == "__main__":
    main()