sabato 28 febbraio 2026

Mixture-of-Experts Variational Autoencoder

Leggendo la documentazione della libreria HyperCoast ho trovato un riferimento a MoE-VAE. Si tratta di una rete neurale complessa basata su


Mixture-of-Experts Variational Autoencoder for Clustering and Generating from Similarity-Based Representations on Single Cell Data 
Andreas Kopf, Vincent Fortuin, Vignesh Ram Somnath, Manfred Claassen

https://arxiv.org/abs/1910.07763

che ha il vantaggio di lavorare su una moltitudine di dati come quelli iperspettrali

Ho provato, tramite AI, con i dati di Indian Pines:

A sinistra la verità a terra, a destra il risultato del modello.

 

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from scipy.signal import medfilt2d

# --- 1. ROBUST DATA LOADING ---
def load_indian_pines(data_path='Indian_pines_corrected.mat',
                      gt_path='Indian_pines_gt.mat'):
    """Load the Indian Pines hyperspectral cube and its ground truth.

    Parameters
    ----------
    data_path : str
        Path to the .mat file containing the 'indian_pines_corrected' cube.
    gt_path : str
        Path to the .mat file containing the 'indian_pines_gt' label map.

    Returns
    -------
    tuple
        (x_scaled, gt, h, w, bands): per-pixel standardized spectra as a
        float32 tensor of shape (h*w, bands), the (h, w) ground-truth label
        array, the spatial dimensions, and the number of retained bands.
    """
    data = sio.loadmat(data_path)['indian_pines_corrected']
    gt = sio.loadmat(gt_path)['indian_pines_gt']
    h, w, b = data.shape
    # The corrected cube has 200 bands; only the uncorrected 220-band cube
    # still contains the water-absorption bands, so drop them conditionally.
    if b > 200:
        ignored_bands = [103, 104, 105, 106, 107, 108, 149, 150, 151, 152,
                         153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163]
        data = np.delete(data, [i for i in ignored_bands if i < b], axis=2)
    x = data.reshape(-1, data.shape[2]).astype(float)
    # Standardize each band (zero mean, unit variance) before feeding the VAE.
    scaler = StandardScaler()
    x_scaled = scaler.fit_transform(x)
    return torch.tensor(x_scaled, dtype=torch.float32), gt, h, w, data.shape[2]

# --- 2. IMPROVED MODEL ---
class Expert(nn.Module):
    """Single VAE expert: MLP encoder to a diagonal Gaussian latent, MLP decoder back."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder trunk shared by the mu / logvar heads.
        self.enc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def forward(self, x):
        """Encode x, sample via the reparameterization trick, and decode.

        Returns (reconstruction, mu, logvar).
        """
        hidden = self.enc(x)
        mean, log_var = self.mu(hidden), self.logvar(hidden)
        # Reparameterization: z = mu + eps * sigma with eps ~ N(0, I).
        sigma = torch.exp(0.5 * log_var)
        latent = mean + torch.randn_like(sigma) * sigma
        return self.dec(latent), mean, log_var

class MoESimVAE(nn.Module):
    """Mixture-of-experts VAE: a softmax gate blends per-expert reconstructions."""

    def __init__(self, input_dim, latent_dim, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(
            Expert(input_dim, latent_dim) for _ in range(num_experts)
        )
        # Gating network producing per-sample mixture weights over the experts.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return (blended reconstruction, list of mus, list of logvars, gate weights)."""
        weights = self.gate(x)
        mus, logvars = [], []
        blended = 0
        for idx, expert in enumerate(self.experts):
            recon, mu, lv = expert(x)
            # Weight each expert's reconstruction by its gate probability.
            blended = blended + weights[:, idx].unsqueeze(1) * recon
            mus.append(mu)
            logvars.append(lv)
        return blended, mus, logvars, weights

# --- 3. THE "SIM" LOSS ---
def compute_loss(recon, x, mus, lvs, w, sigma=0.5):
    """MoE-Sim-VAE objective: reconstruction + gate-weighted KLD
    + RBF similarity matching - gate-entropy bonus.

    recon : blended reconstruction, same shape as x
    x     : input batch of shape (batch, features)
    mus, lvs : per-expert latent means / log-variances
    w     : (batch, num_experts) gate weights
    sigma : RBF kernel bandwidth
    """
    # Reconstruction term: summed squared error over the whole batch.
    mse = F.mse_loss(recon, x, reduction='sum')
    kld = 0
    comb_mu = torch.zeros_like(mus[0])
    for idx in range(len(mus)):
        # Standard Gaussian KLD per expert, scaled by its mean gate weight.
        kld = kld + w[:, idx].mean() * -0.5 * torch.sum(
            1 + lvs[idx] - mus[idx].pow(2) - lvs[idx].exp()
        )
        # Gate-weighted combination of the expert latent means.
        comb_mu = comb_mu + w[:, idx].unsqueeze(1) * mus[idx]

    # Similarity term: the latent-space RBF kernel should match the
    # input-space RBF kernel on a random subset (keeps the O(n^2) cost low).
    subset_idx = torch.randperm(x.size(0))[:64]
    xs, zs = x[subset_idx], comb_mu[subset_idx]
    k_x = torch.exp(-torch.cdist(xs, xs).pow(2) / (2 * sigma ** 2))
    k_z = torch.exp(-torch.cdist(zs, zs).pow(2) / (2 * sigma ** 2))
    sim_loss = F.mse_loss(k_z, k_x) * 50.0  # High weight for Sim
    # Entropy of the mean gate distribution discourages expert collapse.
    mean_w = w.mean(0)
    entropy = -torch.sum(mean_w * torch.log(mean_w + 1e-8))
    return mse + kld + sim_loss - (0.5 * entropy)

# --- 4. TRAIN AND MAP ---
def main():
    """Train the MoE-Sim-VAE on Indian Pines and plot the classification map."""
    x_tensor, gt, h, w, b = load_indian_pines()
    loader = DataLoader(TensorDataset(x_tensor), batch_size=64, shuffle=True)
    model = MoESimVAE(b, 25)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("Training... (Aiming for better separation)")
    for epoch in range(30):
        for (batch_x,) in loader:
            opt.zero_grad()
            recon, mus, lvs, gate_w = model(batch_x)
            loss = compute_loss(recon, batch_x, mus, lvs, gate_w)
            loss.backward()
            opt.step()

    # Extract the gate-weighted latent embedding for every pixel.
    model.eval()
    with torch.no_grad():
        _, mus, _, gate_w = model(x_tensor)
        z = sum(gate_w[:, i].unsqueeze(1) * mus[i] for i in range(len(mus))).numpy()

    # Classify the embedding with an RBF-kernel SVM trained on 10% of the
    # labeled pixels (label 0 is the unlabeled background).
    y = gt.ravel()
    labeled = np.where(y > 0)[0]
    clf = SVC(kernel='rbf', C=10)
    clf.fit(z[labeled[::10]], y[labeled[::10]])
    preds = clf.predict(z).reshape(h, w)
    preds[gt == 0] = 0
    # Spatial cleanup: a 3x3 median filter removes salt-and-pepper noise.
    preds_clean = medfilt2d(preds.astype(float), kernel_size=3)
    preds_clean[gt == 0] = 0

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1); plt.imshow(gt, cmap='nipy_spectral'); plt.title("Ground Truth")
    plt.subplot(1, 2, 2); plt.imshow(preds_clean, cmap='nipy_spectral'); plt.title("MoE-Sim-VAE (Cleaned)")
    plt.show()


if __name__ == "__main__":
    main()

Nessun commento:

Posta un commento

Kernel Panic QrCode

 In tanti anni ho visto qualche kernel panic, ma in questo formato non mi era mai successo    la cosa curiosa che al riavvio successivo ness...