Leggendo la documentazione della libreria HyperCoast ho trovato un riferimento a MoE-Sim-VAE. Si tratta di una rete neurale complessa basata su
Mixture-of-Experts Variational Autoencoder for Clustering and Generating from Similarity-Based Representations on Single Cell Data
Andreas Kopf, Vincent Fortuin, Vignesh Ram Somnath, Manfred Claassen
https://arxiv.org/abs/1910.07763
che ha il vantaggio di lavorare su una moltitudine di dati come quelli iperspettrali
Ho provato a implementarla tramite AI, con i dati di Indian Pines.
![]() |
| A sinistra: verità a terra. A destra: risultato del modello |
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from scipy.signal import medfilt2d
# --- 1. ROBUST DATA LOADING ---
def load_indian_pines(data_path='Indian_pines_corrected.mat',
                      gt_path='Indian_pines_gt.mat'):
    """Load and standardize the Indian Pines hyperspectral cube.

    Parameters
    ----------
    data_path : str
        Path to the .mat file with the data cube, stored under the
        key 'indian_pines_corrected' with shape (h, w, bands).
    gt_path : str
        Path to the .mat file with the ground-truth label map, stored
        under the key 'indian_pines_gt' with shape (h, w).

    Returns
    -------
    tuple
        (x, gt, h, w, bands) where x is a float32 tensor of shape
        (h*w, bands) with each band z-scored over the whole scene,
        gt is the (h, w) integer label array, and h/w/bands are the
        cube dimensions after band removal.
    """
    data = sio.loadmat(data_path)['indian_pines_corrected']
    gt = sio.loadmat(gt_path)['indian_pines_gt']
    h, w, b = data.shape
    # Only strip these bands when the cube still contains them; the
    # "corrected" release ships with 200 bands already cleaned, and
    # deleting unconditionally caused an out-of-range index (219).
    # NOTE(review): these indices look like the usual noisy/water-absorption
    # bands for this sensor — confirm against the dataset description.
    if b > 200:
        ignored_bands = [103, 104, 105, 106, 107, 108, 149, 150, 151, 152,
                         153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163]
        data = np.delete(data, [i for i in ignored_bands if i < b], axis=2)
    # Flatten to (pixels, bands) and z-score each band.
    x = data.reshape(-1, data.shape[2]).astype(float)
    scaler = StandardScaler()
    x_scaled = scaler.fit_transform(x)
    return torch.tensor(x_scaled, dtype=torch.float32), gt, h, w, data.shape[2]
# --- 2. IMPROVED MODEL ---
class Expert(nn.Module):
    """A single VAE expert: MLP encoder to a diagonal-Gaussian latent,
    reparameterized sample, MLP decoder back to the input space."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        # Separate heads for the Gaussian parameters of the latent.
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def forward(self, x):
        """Encode x, sample a latent via the reparameterization trick,
        and decode. Returns (reconstruction, mu, logvar)."""
        hidden = self.enc(x)
        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        # z = mu + eps * sigma, with sigma = exp(logvar / 2).
        noise = torch.randn_like(mu)
        z = mu + noise * torch.exp(0.5 * logvar)
        return self.dec(z), mu, logvar
class MoESimVAE(nn.Module):
    """Mixture-of-experts VAE: a softmax gate on the input blends the
    reconstructions of several independent Expert VAEs."""

    def __init__(self, input_dim, latent_dim, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(
            Expert(input_dim, latent_dim) for _ in range(num_experts)
        )
        # Gate maps each input to a distribution over the experts.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return (blended reconstruction, list of per-expert mus,
        list of per-expert logvars, gate weights)."""
        weights = self.gate(x)
        mus, logvars = [], []
        blended = 0
        for idx, expert in enumerate(self.experts):
            recon, mu, logvar = expert(x)
            blended = blended + weights[:, idx].unsqueeze(1) * recon
            mus.append(mu)
            logvars.append(logvar)
        return blended, mus, logvars, weights
# --- 3. THE "SIM" LOSS ---
def compute_loss(recon, x, mus, lvs, w, sigma=0.5, sim_weight=50.0,
                 entropy_weight=0.5, subset_size=64):
    """MoE-Sim-VAE training objective.

    Combines: summed reconstruction MSE, gate-weighted KL divergence of
    each expert's latent to N(0, I), an RBF-kernel similarity-matching
    term between input space and the gate-weighted latent means, and a
    gate-entropy bonus that discourages expert collapse.

    Parameters
    ----------
    recon : Tensor (batch, features)  blended reconstruction.
    x : Tensor (batch, features)      original input batch.
    mus, lvs : list[Tensor]           per-expert latent means / logvars.
    w : Tensor (batch, num_experts)   gate weights (rows sum to 1).
    sigma : float                     RBF kernel bandwidth.
    sim_weight : float                multiplier for the similarity term
                                      (was hard-coded to 50.0).
    entropy_weight : float            multiplier for the entropy bonus
                                      (was hard-coded to 0.5).
    subset_size : int                 random subset used for the kernel
                                      matrices (was hard-coded to 64).

    Returns
    -------
    Tensor  scalar loss.
    """
    # Reconstruction error, summed over batch and features.
    mse = F.mse_loss(recon, x, reduction='sum')
    # Per-expert KLD to the standard normal, weighted by the mean gate
    # usage of that expert; also build the gate-weighted latent mean.
    kld = 0
    comb_mu = torch.zeros_like(mus[0])
    for i in range(len(mus)):
        kld += w[:, i].mean() * -0.5 * torch.sum(1 + lvs[i] - mus[i].pow(2) - lvs[i].exp())
        comb_mu += w[:, i].unsqueeze(1) * mus[i]
    # RBF similarity matching on a random subset (speed/stability): the
    # pairwise kernel matrix in latent space should match the one in
    # input space.
    subset_idx = torch.randperm(x.size(0))[:subset_size]
    xs, zs = x[subset_idx], comb_mu[subset_idx]
    k_x = torch.exp(-torch.cdist(xs, xs).pow(2) / (2 * sigma ** 2))
    k_z = torch.exp(-torch.cdist(zs, zs).pow(2) / (2 * sigma ** 2))
    sim_loss = F.mse_loss(k_z, k_x) * sim_weight
    # Entropy of the average gate distribution (rewarded, hence the
    # minus sign below) to prevent all mass collapsing on one expert.
    avg_w = w.mean(0)
    entropy = -torch.sum(avg_w * torch.log(avg_w + 1e-8))
    return mse + kld + sim_loss - entropy_weight * entropy
# --- 4. TRAIN AND MAP ---
def main():
    """Train the MoE-Sim-VAE on Indian Pines, classify the latent codes
    with an RBF SVM, median-filter the label map, and plot it next to
    the ground truth."""
    x_tensor, gt, height, width, bands = load_indian_pines()
    loader = DataLoader(TensorDataset(x_tensor), batch_size=64, shuffle=True)
    model = MoESimVAE(bands, 25)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    print("Training... (Aiming for better separation)")
    for _ in range(30):
        for (xb,) in loader:
            opt.zero_grad()
            recon, mus, lvs, gate_w = model(xb)
            loss = compute_loss(recon, xb, mus, lvs, gate_w)
            loss.backward()
            opt.step()
    # Extract the gate-weighted latent mean for every pixel.
    model.eval()
    with torch.no_grad():
        _, mus, _, gate_w = model(x_tensor)
        z = sum(gate_w[:, i].unsqueeze(1) * mus[i] for i in range(len(mus))).numpy()
    # Fit an RBF SVM on every 10th labeled pixel (~10% of the labels),
    # then predict a class for the whole scene.
    y = gt.ravel()
    labeled = np.where(y > 0)[0]
    clf = SVC(kernel='rbf', C=10)
    train_idx = labeled[::10]
    clf.fit(z[train_idx], y[train_idx])
    preds = clf.predict(z).reshape(height, width)
    preds[gt == 0] = 0
    # Spatial cleanup: a 3x3 median filter suppresses isolated
    # salt-and-pepper misclassifications.
    preds_clean = medfilt2d(preds.astype(float), kernel_size=3)
    preds_clean[gt == 0] = 0
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1); plt.imshow(gt, cmap='nipy_spectral'); plt.title("Ground Truth")
    plt.subplot(1, 2, 2); plt.imshow(preds_clean, cmap='nipy_spectral'); plt.title("MoE-Sim-VAE (Cleaned)")
    plt.show()
if __name__ == "__main__":
    main()

Nessun commento:
Posta un commento