from pathlib import Path

import numpy as np
import rasterio
import xarray as xr
from rasterio.transform import from_origin
from rasterio.warp import reproject, Resampling
# --- INPUT/OUTPUT ---
nc_file = "EMIT_L2A_RFL_001_20250609T081616_2516005_003.nc"
output_img = "emit_georef.img"
# Value written to every output cell that has no valid source observation.
NODATA = -9999


def grid_from_geotransform(geotransform):
    """Return ``(west, north, x_res, y_res)`` for a north-up GDAL geotransform.

    *geotransform* is the 6-element GDAL sequence
    ``(x_origin, x_res, x_rot, y_origin, y_rot, y_step)`` with ``y_step < 0``.
    Both resolutions are returned as positive numbers, ready for
    ``rasterio.transform.from_origin``.

    Raises:
        ValueError: if the sequence does not have 6 elements, has non-zero
            rotation terms, or is not north-up.
    """
    gt = [float(v) for v in np.ravel(geotransform)]
    if len(gt) != 6:
        raise ValueError(f"expected a 6-element geotransform, got {len(gt)} values")
    west, x_res, x_rot, north, y_rot, y_step = gt
    if x_rot != 0 or y_rot != 0:
        raise ValueError("rotated geotransforms are not supported")
    if x_res <= 0 or y_step >= 0:
        raise ValueError("expected a north-up geotransform (x_res > 0, y_step < 0)")
    return west, north, x_res, -y_step


def glt_to_indices(glt_x, glt_y, n_downtrack, n_crosstrack):
    """Convert a 1-based geometry lookup table (GLT) into 0-based swath indices.

    ``glt_x`` / ``glt_y`` are same-shaped 2-D arrays laid out on the output map
    grid. Each cell holds the 1-based crosstrack / downtrack position of the raw
    swath pixel that lands there. This follows the EMIT ``location`` group
    convention; 0 marks "no pixel". Zero, NaN, infinite, negative or
    out-of-range entries are all treated as "no source pixel" rather than
    being allowed to wrap around via negative indexing.

    Returns:
        ``(valid, src_rows, src_cols)``: *valid* is a boolean mask with the
        GLT's shape. *src_rows* / *src_cols* are 0-based downtrack /
        crosstrack indices for the ``True`` cells, in row-major order.

    Raises:
        ValueError: if the two GLT arrays differ in shape.
    """
    def as_int(a):
        # NaN appears when xarray masks the GLT fill value; map it to 0 (= invalid).
        arr = np.asarray(a, dtype=np.float64)
        return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.int64)

    gx = as_int(glt_x)
    gy = as_int(glt_y)
    if gx.shape != gy.shape:
        raise ValueError(f"glt_x shape {gx.shape} != glt_y shape {gy.shape}")
    valid = (gx >= 1) & (gx <= n_crosstrack) & (gy >= 1) & (gy <= n_downtrack)
    return valid, gy[valid] - 1, gx[valid] - 1


def orthorectify_band(band, valid, src_rows, src_cols, nodata=NODATA):
    """Place one swath band of shape ``(downtrack, crosstrack)`` onto the GLT grid.

    Cells outside *valid* are set to *nodata*. So are source pixels that are
    NaN: xarray decodes the file's ``_FillValue`` to NaN by default, so an
    explicit ``-9999`` never appears. Returns a float32 array of
    ``valid.shape``.
    """
    out = np.full(valid.shape, nodata, dtype=np.float32)
    samples = np.asarray(band, dtype=np.float32)[src_rows, src_cols]  # fancy index -> copy
    samples[np.isnan(samples)] = nodata
    out[valid] = samples
    return out


def build_envi_header(description, n_cols, n_rows, wavelengths, fwhm,
                      west, north, x_res, y_res, nodata=NODATA):
    """Return the text of an ENVI ``.hdr`` for a float32 BSQ geographic raster.

    *west* / *north* are the outer upper-left corner, and the ENVI reference
    pixel (1, 1) is that pixel's upper-left corner. The band count is taken
    from ``len(wavelengths)``.
    """
    bands = len(wavelengths)
    map_info = (
        f"Geographic Lat/Lon, 1, 1, "
        f"{west:.10f}, {north:.10f}, "
        f"{x_res:.10f}, {y_res:.10f}, "
        f"WGS-84, units=Degrees"
    )
    wl_str = ", ".join(f"{w:.5f}" for w in wavelengths)
    fwhm_str = ", ".join(f"{v:.5f}" for v in fwhm)
    bn_str = ", ".join(f"Band {i + 1}" for i in range(bands))
    lines = [
        "ENVI",
        f"description = {{{description}}}",
        f"samples = {n_cols}",
        f"lines = {n_rows}",
        f"bands = {bands}",
        "header offset = 0",
        "file type = ENVI Standard",
        "data type = 4",  # 4 = float32
        "interleave = bsq",
        # NOTE(review): 0 = little-endian; assumes GDAL wrote native order on a
        # little-endian host.
        "byte order = 0",
        f"map info = {{{map_info}}}",
        'coordinate system string = {GEOGCS["GCS_WGS_1984",'
        'DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],'
        'PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]]}',
        f"band names = {{{bn_str}}}",
        f"wavelength = {{{wl_str}}}",
        "wavelength units = Nanometers",
        f"fwhm = {{{fwhm_str}}}",
        f"data ignore value = {nodata}",
    ]
    return "\n".join(lines) + "\n"


def main(nc_path=nc_file, out_path=output_img):
    """Orthorectify the EMIT reflectance cube in *nc_path* into an ENVI raster.

    The raw swath is placed on the product's own north-up map grid using its
    geometry lookup table (``location/glt_x``, ``location/glt_y``) and the
    global ``geotransform`` attribute. Each output cell takes the value of the
    swath pixel the GLT points to, with no resampling. *out_path* is written
    band by band (float32, EPSG:4326) to keep memory low. The GDAL-generated
    header is then replaced with one carrying wavelengths and FWHM.

    Raises:
        KeyError: if the file has no ``geotransform`` global attribute.
        ValueError: if wavelength/FWHM counts do not match the band count, or
            the geotransform is not a plain north-up transform.
    """
    with xr.open_dataset(nc_path) as ds_ref:
        refl = ds_ref["reflectance"]  # (downtrack, crosstrack, bands) — lazy
        bands = refl.sizes["bands"]
        dt = refl.sizes["downtrack"]
        ct = refl.sizes["crosstrack"]
        print(f"Raw swath: {dt} downtrack × {ct} crosstrack × {bands} bands")

        # --- WAVELENGTHS / FWHM ---
        with xr.open_dataset(nc_path, group="sensor_band_parameters") as ds_params:
            wavelengths = ds_params["wavelengths"].values
            fwhm = (ds_params["fwhm"].values if "fwhm" in ds_params
                    else np.full_like(wavelengths, 2.5))
        if len(wavelengths) != bands or len(fwhm) != bands:
            raise ValueError(
                f"{len(wavelengths)} wavelengths / {len(fwhm)} fwhm values "
                f"for {bands} bands"
            )

        # --- GEOLOCATION: lat/lon for diagnostics, GLT for the actual mapping ---
        with xr.open_dataset(nc_path, group="location") as ds_loc:
            lat = ds_loc["lat"].values  # (downtrack, crosstrack)
            lon = ds_loc["lon"].values
            valid, src_rows, src_cols = glt_to_indices(
                ds_loc["glt_x"].values, ds_loc["glt_y"].values, dt, ct
            )
        print(f"lat corners: TL={lat[0,0]:.4f} TR={lat[0,-1]:.4f} BL={lat[-1,0]:.4f} BR={lat[-1,-1]:.4f}")
        print(f"lon corners: TL={lon[0,0]:.4f} TR={lon[0,-1]:.4f} BL={lon[-1,0]:.4f} BR={lon[-1,-1]:.4f}")

        # --- TARGET GRID: the product's own north-up GLT grid ---
        try:
            geotransform = ds_ref.attrs["geotransform"]
        except KeyError as err:
            raise KeyError(
                f"{nc_path} has no 'geotransform' attribute; cannot place the GLT grid"
            ) from err
        west, north, x_res, y_res = grid_from_geotransform(geotransform)
        n_rows, n_cols = valid.shape
        print(f"Output grid: {n_cols} cols × {n_rows} rows ({n_rows*n_cols*bands*4/1e9:.2f} GB)")
        dst_transform = from_origin(west, north, x_res, y_res)

        # --- ORTHORECTIFY BAND BY BAND (low memory: one band at a time) ---
        with rasterio.open(
            out_path, "w",
            driver="ENVI",
            height=n_rows, width=n_cols,
            count=bands,
            dtype=np.float32,
            crs="EPSG:4326",
            transform=dst_transform,
        ) as dst:
            for b in range(bands):
                band_raw = refl.isel(bands=b).values  # (downtrack, crosstrack)
                dst.write(orthorectify_band(band_raw, valid, src_rows, src_cols), b + 1)
                if b % 20 == 0:
                    print(f" band {b+1}/{bands}")
    print("Warp complete — writing HDR...")

    # --- ENVI HDR (overwrites the one GDAL wrote on close) ---
    # with_suffix mirrors GDAL's default header naming and, unlike
    # str.replace(".img", ...), can never resolve to the data file itself.
    hdr_file = Path(out_path).with_suffix(".hdr")
    hdr_file.write_text(
        build_envi_header(str(out_path), n_cols, n_rows, wavelengths, fwhm,
                          west, north, x_res, y_res),
        encoding="utf-8",
    )

    east = west + n_cols * x_res
    south = north - n_rows * y_res
    print("✅ Done:", out_path)
    print(f" Extent: lon [{west:.5f} → {east:.5f}] lat [{south:.5f} → {north:.5f}]")


if __name__ == "__main__":
    main()