I am trying to use MARIDA, an archive of annotated Sentinel-2 images, collected in non-Mediterranean areas, with 15 categories.
The dataset consists of patches (GeoTIFF images with 10 m spatial resolution, a 2560x2560 m footprint, 256x256 pixels, and 11 bands). Note that the 20 m and 60 m bands have been resampled to 10 m.
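Before feeding a patch to the network it is worth sanity-checking its layout. A minimal sketch, assuming a local MARIDA patch named "test4.tif" (the same file used below):

import rasterio

with rasterio.open("test4.tif") as src:
    print("bands:", src.count)               # expected: 11
    print("size:", src.width, src.height)    # expected: 256 x 256
    print("dtype:", src.dtypes[0])
    print("CRS:", src.crs)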
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet50Encoder(nn.Module):
    def __init__(self, input_bands=11, output_classes=11):
        super().__init__()
        resnet = models.resnet50(weights=None)  # random init; trained weights are loaded from disk below
        self.encoder = nn.Sequential(
            # Replace the stock 3-channel first conv with an 11-band one
            nn.Conv2d(input_bands, 64, kernel_size=7, stride=2, padding=3, bias=False),
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool,
        )
        self.fc = nn.Linear(2048, output_classes)

    def forward(self, x):
        x = self.encoder(x)        # [B, 2048, 1, 1] after global average pooling
        x = x.view(x.size(0), -1)  # flatten to [B, 2048]
        logits = self.fc(x)
        return logits
# Instantiate and load weights
model = ResNet50Encoder(input_bands=11, output_classes=11)
model.load_state_dict(torch.load("model_resnet.pth", map_location="cpu"))
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
import rasterio
import numpy as np
with rasterio.open("test4.tif") as src:
    image = src.read().astype(np.float32)  # (11, 256, 256)

# Normalize to [0, 1] with a single min-max over all bands
# (per-band statistics are often preferable for multispectral data)
image = (image - image.min()) / (image.max() - image.min())
image_tensor = torch.from_numpy(image).unsqueeze(0) # (1, 11, 256, 256)
image_tensor = image_tensor.to(device)
with torch.no_grad():
    output = model(image_tensor)

# This model is an image-level classifier, so output has shape [1, 11]:
predicted_class = torch.argmax(output, dim=1)
print("Predicted class:", predicted_class.item())
# Pixel-level classification: this classifier cannot produce it directly.
# Global average pooling + the fc layer collapse the spatial dimensions, so
# the output is [1, 11], not [1, num_classes, 256, 256]. A decoder head is
# needed; see ResNetSegmentationFromClassifier below.
from PIL import Image
# Example: a predicted mask (torch.Tensor, shape [H, W]) with 11 classes
mask = torch.randint(0, 11, (256, 256), dtype=torch.uint8)  # demo mask

# Option 1: using PIL directly (recommended for class masks)
mask_np = mask.numpy().astype(np.uint8)
img = Image.fromarray(mask_np, mode="L")  # 'L' = 8-bit grayscale
img.save("predicted_mask.png")
# Option 2: a palettized PNG with one color per class
palette = [
    0, 0, 0,        # class 0 - black
    255, 0, 0,      # class 1 - red
    0, 255, 0,      # class 2 - green
    0, 0, 255,      # class 3 - blue
    255, 255, 0,    # class 4 - yellow
    255, 0, 255,    # class 5 - magenta
    0, 255, 255,    # class 6 - cyan
    128, 128, 128,  # class 7 - gray
    255, 128, 0,    # class 8 - orange
    128, 0, 255,    # class 9 - violet
    0, 128, 255,    # class 10 - sky blue
]

# Apply the palette
color_mask = Image.fromarray(mask_np, mode="P")
color_mask.putpalette(palette)
color_mask.save("colored_mask.png")
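For a quick visual check against the scene, the palettized mask can be blended over an RGB composite of the patch. A sketch, assuming the 11-band stack follows MARIDA's band order (B1, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12), so that B4/B3/B2 sit at indices 3, 2, 1 (verify against your files); it reuses the normalized `image` array loaded above:

# Build an RGB composite; band indices 3, 2, 1 are an assumption about order
rgb = (np.clip(image[[3, 2, 1]], 0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
base = Image.fromarray(rgb, mode="RGB")
blended = Image.blend(base, color_mask.convert("RGB"), alpha=0.4)
blended.save("mask_overlay.png")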
##########################################################################################
# Define the segmentation model
class ResNetSegmentationFromClassifier(nn.Module):
    def __init__(self, encoder_model):
        super().__init__()
        # Rebuild the encoder part (everything except avgpool and fc)
        resnet = models.resnet50(weights=None)
        # Re-use the custom 11-band first conv layer from the classifier
        first_conv = encoder_model.encoder[0]
        self.encoder = nn.Sequential(
            first_conv,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )
        # Copy the trained weights (and BatchNorm running stats) from the
        # classifier. Its state_dict keys are prefixed with "encoder.", so
        # strip the prefix; avgpool has no parameters, so the remaining keys
        # match this encoder exactly. (Matching names via named_parameters()
        # alone would silently copy nothing and miss the BN buffers.)
        enc_sd = {
            k[len("encoder."):]: v
            for k, v in encoder_model.state_dict().items()
            if k.startswith("encoder.")
        }
        self.encoder.load_state_dict(enc_sd)
        # Decoder to upsample 8x8 back to 256x256.
        # NOTE: these layers are randomly initialized; the decoder must be
        # trained on pixel-level labels before its predictions mean anything.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2048, 512, 2, stride=2),  # 8x8    -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 2, stride=2),   # 16x16  -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, stride=2),   # 32x32  -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, stride=2),    # 64x64  -> 128x128
            nn.ReLU(),
            nn.ConvTranspose2d(64, 11, 2, stride=2),     # 128x128 -> 256x256
        )

    def forward(self, x):
        x = self.encoder(x)  # [B, 2048, 8, 8] for a 256x256 input
        x = self.decoder(x)  # [B, 11, 256, 256]
        return x
# Load the classification model
model = ResNet50Encoder(input_bands=11, output_classes=11)
model.load_state_dict(torch.load("model_resnet.pth", map_location="cpu"))
model.eval()
# Create the segmentation model
segmentation_model = ResNetSegmentationFromClassifier(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
segmentation_model.to(device)
segmentation_model.eval()
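One caveat before running inference: the decoder was just created with random weights, so the mask below will be essentially noise until the model is fine-tuned on MARIDA's pixel-level annotations. A minimal fine-tuning sketch, assuming a hypothetical train_loader that yields (images, masks) batches, with images [B, 11, 256, 256] and masks [B, 256, 256] of class indices:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# Freeze the pretrained encoder and train only the decoder (one common choice)
for p in segmentation_model.encoder.parameters():
    p.requires_grad = False
optimizer = optim.Adam(segmentation_model.decoder.parameters(), lr=1e-3)

segmentation_model.train()
for epoch in range(10):
    for images, masks in train_loader:  # hypothetical DataLoader of MARIDA patches
        images = images.to(device)
        masks = masks.to(device).long()
        optimizer.zero_grad()
        logits = segmentation_model(images)  # [B, 11, 256, 256]
        loss = criterion(logits, masks)      # per-pixel cross-entropy
        loss.backward()
        optimizer.step()

segmentation_model.eval()  # back to inference mode for the code below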
# Load and preprocess the image
with rasterio.open("test4.tif") as src:
    image = src.read().astype(np.float32)  # (11, 256, 256)

# Normalize to [0, 1] (same global min-max as above)
image = (image - image.min()) / (image.max() - image.min())
image_tensor = torch.from_numpy(image).unsqueeze(0) # (1, 11, 256, 256)
image_tensor = image_tensor.to(device)
# Perform pixel-level classification
with torch.no_grad():
    output = segmentation_model(image_tensor)  # [1, 11, 256, 256]

# Create the class mask by taking the argmax at each pixel
mask = torch.argmax(output, dim=1).squeeze(0)  # shape: [256, 256]
mask_np = mask.cpu().numpy().astype(np.uint8)
# Define the color palette (11 classes, RGB triplets)
palette = [
    0, 0, 0,        # class 0 - black
    255, 0, 0,      # class 1 - red
    0, 255, 0,      # class 2 - green
    0, 0, 255,      # class 3 - blue
    255, 255, 0,    # class 4 - yellow
    255, 0, 255,    # class 5 - magenta
    0, 255, 255,    # class 6 - cyan
    128, 128, 128,  # class 7 - gray
    255, 128, 0,    # class 8 - orange
    128, 0, 255,    # class 9 - violet
    0, 128, 255,    # class 10 - sky blue
]
# Save grayscale mask
img = Image.fromarray(mask_np, mode="L")
img.save("predicted_mask.png")
# Save colored mask
color_mask = Image.fromarray(mask_np, mode="P")
color_mask.putpalette(palette)
color_mask.save("colored_mask.png")
# Print class distribution
unique, counts = np.unique(mask_np, return_counts=True)
class_distribution = dict(zip(unique, counts))
print("Pixel class distribution:", class_distribution)