import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, cohen_kappa_score, accuracy_score
import matplotlib.pyplot as plt
import os
# ==========================================
# Configuration
# ==========================================
class Config:
    """Central hyper-parameter store for the SpectralFormer pipeline."""

    # --- Dataset locations ---
    data_path = './'
    img_mat = 'PaviaU.mat'
    gt_mat = 'PaviaU_gt.mat'

    # --- Preprocessing ---
    patch_size = 9        # spatial window extracted around each labelled pixel
    pca_components = 30   # spectral bands kept after PCA (103 -> 30)
    num_classes = 9
    train_ratio = 0.01    # fraction of labelled pixels used for training

    # --- Model ---
    embed_dim = 64        # token embedding width
    num_heads = 4
    num_layers = 3        # number of Transformer blocks
    mlp_ratio = 2.0
    dropout = 0.1

    # --- Optimisation ---
    batch_size = 64
    epochs = 50
    learning_rate = 0.001
    weight_decay = 1e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


cfg = Config()
# ==========================================
# Data Loading & Preprocessing
# ==========================================
def load_pavia_data(cfg):
    """Load the Pavia University cube and its ground truth, then PCA-compress
    the spectral axis down to ``cfg.pca_components`` bands.

    Returns:
        tuple: (data_pca of shape (H, W, pca_components), gt of shape (H, W)).
    """
    cube = sio.loadmat(os.path.join(cfg.data_path, cfg.img_mat))['paviaU']
    gt = sio.loadmat(os.path.join(cfg.data_path, cfg.gt_mat))['paviaU_gt']

    height, width, bands = cube.shape
    # PCA operates on a (pixels, bands) matrix, so flatten the spatial dims.
    flat = cube.reshape(-1, bands)
    reduced = PCA(n_components=cfg.pca_components, whiten=True).fit_transform(flat)
    return reduced.reshape(height, width, cfg.pca_components), gt
def create_patches(data, gt, patch_size):
    """Extract a (patch_size x patch_size x bands) window around every
    labelled pixel.

    The cube is reflect-padded so that border pixels still get full-size
    patches.  (Fix: the previous version also padded the ground-truth map
    into an unused local; that dead code is removed.)

    Args:
        data: (H, W, B) hyperspectral cube.
        gt: (H, W) ground-truth map; 0 = unlabelled, classes start at 1.
        patch_size: spatial window size (expected odd so the pixel is centred).

    Returns:
        tuple: (patches of shape (N, patch_size, patch_size, B),
                labels of shape (N,) with values in 0..num_classes-1).
    """
    pad = patch_size // 2
    # Reflect-pad only the spatial axes; the spectral axis stays intact.
    data_padded = np.pad(data, ((pad, pad), (pad, pad), (0, 0)), mode='reflect')

    patches = []
    labels = []
    # Only labelled pixels (gt > 0) become samples.
    y_indices, x_indices = np.where(gt > 0)
    for y, x in zip(y_indices, x_indices):
        # In padded coordinates the centre is (y + pad, x + pad), so the
        # window [y, y + patch_size) covers the pixel's full neighbourhood.
        patches.append(data_padded[y:y + patch_size, x:x + patch_size, :])
        labels.append(gt[y, x] - 1)  # Pavia labels are 1..9 -> 0..8
    return np.array(patches), np.array(labels)
class HSIDataset(Dataset):
    """Wraps (N, H, W, B) patch arrays as a torch Dataset in NCHW layout."""

    def __init__(self, patches, labels):
        # Conv2d expects channels-first, so reorder (N, H, W, B) -> (N, B, H, W).
        self.patches = torch.FloatTensor(patches).permute(0, 3, 1, 2)
        self.labels = torch.LongTensor(labels)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        return self.patches[idx], self.labels[idx]
# ==========================================
# SpectralFormer Model (FIXED)
# ==========================================
class MSA(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        # Scale attention scores by 1/sqrt(head_dim) to keep softmax stable.
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # Single projection, then split into per-head q/k/v tensors of shape
        # (batch, heads, tokens, head_dim).
        projected = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge heads back into the channel dimension and project out.
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(context))
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: MSA and an MLP, each with a residual path."""

    def __init__(self, dim, num_heads, mlp_ratio=2.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MSA(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # LayerNorm is applied before each sub-layer (pre-norm variant),
        # with identity skip connections around both.
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class SpectralFormer(nn.Module):
    """
    Transformer classifier for hyperspectral patches.

    Pipeline: a single Conv2d (kernel == stride == patch size) collapses each
    (pca_components, P, P) patch into one embed_dim token — BatchNorm2d is
    used while the tensor is still in (B, C, H, W) layout — a learnable
    positional embedding is added, the token passes through a stack of
    pre-norm Transformer blocks, and a LayerNorm + Linear head emits logits.
    """

    def __init__(self, cfg):
        super().__init__()
        self.patch_size = cfg.patch_size
        self.embed_dim = cfg.embed_dim
        self.num_layers = cfg.num_layers

        # 1. Patch embedding: spatial grid collapses to 1x1, so every patch
        #    becomes exactly one token.
        self.patch_embed = nn.Sequential(
            nn.Conv2d(cfg.pca_components, cfg.embed_dim,
                      kernel_size=cfg.patch_size, stride=cfg.patch_size),
            nn.BatchNorm2d(cfg.embed_dim),  # correct norm for (B, C, H, W)
            nn.GELU()
        )

        # 2. Learnable positional embedding (sequence length is 1),
        #    initialised with a small std like ViT.
        self.pos_embed = nn.Parameter(torch.randn(1, 1, cfg.embed_dim) * 0.02)

        # 3. Transformer encoder stack.
        self.blocks = nn.ModuleList(
            TransformerBlock(cfg.embed_dim, cfg.num_heads, cfg.mlp_ratio, cfg.dropout)
            for _ in range(cfg.num_layers)
        )

        # 4. Sequence-format LayerNorm + linear classifier.
        self.norm = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier for linears, unit gain for norm layers, Kaiming for convs.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        # x: (B, pca_components, P, P) -> (B, embed_dim, 1, 1)
        tokens = self.patch_embed(x)
        # Flatten into the (B, N, C) sequence layout; N == 1 here.
        tokens = tokens.flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens)

        # Average-pool over the (length-1) sequence and classify.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)
# ==========================================
# Training & Evaluation Functions
# ==========================================
def train_model(model, train_loader, val_loader, cfg):
    """Train with Adam + step LR decay, checkpointing the weights with the
    best validation accuracy to 'best_spectralformer.pth'.

    Returns:
        tuple: (per-epoch mean training losses,
                per-epoch validation accuracies in percent).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate,
                           weight_decay=cfg.weight_decay)
    # Halve the learning rate every 10 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    best_acc = 0.0
    train_losses = []
    val_accs = []
    print(f"Starting Training on {cfg.device}...")

    for epoch in range(cfg.epochs):
        # --- one optimisation pass over the training set ---
        model.train()
        running_loss = 0.0
        for patches, labels in train_loader:
            patches = patches.to(cfg.device)
            labels = labels.to(cfg.device)
            optimizer.zero_grad()
            loss = criterion(model(patches), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)

        # --- validation accuracy ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for patches, labels in val_loader:
                patches = patches.to(cfg.device)
                labels = labels.to(cfg.device)
                predicted = model(patches).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = 100 * correct / total
        val_accs.append(val_acc)

        # Keep only the best-performing checkpoint.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_spectralformer.pth')
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{cfg.epochs}], Loss: {avg_loss:.4f}, Val Acc: {val_acc:.2f}%")

    return train_losses, val_accs
def evaluate_model(model, test_loader, cfg):
    """Restore the best checkpoint and report OA, AA, Cohen's kappa and the
    confusion matrix on the test set.

    Returns:
        tuple: (overall accuracy, average accuracy, kappa coefficient).
    """
    model.load_state_dict(torch.load('best_spectralformer.pth', map_location=cfg.device))
    model.eval()

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for patches, labels in test_loader:
            logits = model(patches.to(cfg.device))
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Standard HSI metrics: overall accuracy, kappa, and per-class average.
    oa = accuracy_score(all_labels, all_preds)
    kappa = cohen_kappa_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    # Clamp row sums to 1 so classes with no samples don't divide by zero.
    class_acc = np.diag(cm) / np.maximum(np.sum(cm, axis=1), 1)
    aa = np.mean(class_acc)

    print("\n--- Evaluation Results ---")
    print(f"Overall Accuracy (OA): {oa * 100:.2f}%")
    print(f"Average Accuracy (AA): {aa * 100:.2f}%")
    print(f"Kappa Coefficient: {kappa:.4f}")
    print("Confusion Matrix:\n", cm)
    print("\nPer-class Accuracy:")
    for i, acc in enumerate(class_acc):
        print(f" Class {i+1}: {acc * 100:.2f}%")
    return oa, aa, kappa
# ==========================================
# Main Execution
# ==========================================
def main():
    """Run the full pipeline: load, patch, split, train, evaluate, plot."""
    # Fix seeds so splits and weight initialisation are reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    # 1. Load the PCA-reduced cube and ground truth.
    print("Loading Data...")
    data, gt = load_pavia_data(cfg)

    # 2. Build labelled patches.
    print("Creating Patches...")
    patches, labels = create_patches(data, gt, cfg.patch_size)

    # 3. Stratified split: cfg.train_ratio for training, remainder halved
    #    between validation and test.
    print("Splitting Data...")
    X_train, X_temp, y_train, y_temp = train_test_split(
        patches, labels, train_size=cfg.train_ratio, stratify=labels, random_state=42
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
    )
    print(f"Train: {len(y_train)}, Val: {len(y_val)}, Test: {len(y_test)}")

    # 4. Dataset / DataLoader wrappers (only training batches are shuffled).
    train_loader = DataLoader(HSIDataset(X_train, y_train),
                              batch_size=cfg.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(HSIDataset(X_val, y_val),
                            batch_size=cfg.batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(HSIDataset(X_test, y_test),
                             batch_size=cfg.batch_size, shuffle=False, num_workers=0)

    # 5. Model.
    print("Initializing Model...")
    model = SpectralFormer(cfg).to(cfg.device)
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # 6. Train.
    print("\n" + "="*50)
    train_losses, val_accs = train_model(model, train_loader, val_loader, cfg)

    # 7. Evaluate the best checkpoint on the held-out test set.
    print("\n" + "="*50)
    evaluate_model(model, test_loader, cfg)

    # 8. Save loss / accuracy curves to disk.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True, alpha=0.3)
    plt.subplot(1, 2, 2)
    plt.plot(val_accs, label='Val Accuracy', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Validation Accuracy')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('training_curve.png', dpi=300, bbox_inches='tight')
    print("\nTraining curve saved to training_curve.png")
    plt.close()
if __name__ == '__main__':
    # Ensure the dataset directory exists before main() tries to read
    # the .mat files from it.
    os.makedirs(cfg.data_path, exist_ok=True)
    main()