This commit is contained in:
Victor Phan
2026-01-24 15:48:52 +00:00
parent f264d1856b
commit fb8a107e77
6 changed files with 1760 additions and 0 deletions

270
sen12ms_cr_dataLoader.py Executable file
View File

@@ -0,0 +1,270 @@
"""
Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1,
Sentinel 2 and cloudy Sentinel 2 data.
The SEN12MS-CR class is meant to provide a set of helper routines for loading individual
image patches as well as triplets of patches from the dataset. These routines can easily
be wrapped or extended for use with many deep learning frameworks or as standalone helper
methods. For an example use case please see the "main" routine at the end of this file.
NOTE: Some folder/file existence and validity checks are implemented, but they are
by no means complete.
Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de),
based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied.
"""
import os
import rasterio
import numpy as np
from enum import Enum
from glob import glob
class S1Bands(Enum):
    """Band selectors for Sentinel-1 patches.

    Member values are the band indexes handed to the raster reader in
    ``SEN12MSCRDataset.get_patch``; ``ALL`` and ``NONE`` carry lists of indexes.
    """

    # Individual polarisation bands.
    VV = 1
    VH = 2

    # Convenience selections. Inside the class body VV/VH are still plain
    # ints, so these values are lists of raw indexes, not of enum members.
    ALL = [VV, VH]
    NONE = []
class S2Bands(Enum):
    """Band selectors for Sentinel-2 patches (clear or cloudy).

    Every band has two names: the band code (``B04``) is the canonical member
    and the descriptive name (``red``) is an alias for it. Member values are
    the band indexes handed to the raster reader in
    ``SEN12MSCRDataset.get_patch``.
    """

    # Individual bands: canonical code = descriptive alias = band index.
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13

    # Convenience selections. Inside the class body the names above are still
    # plain ints, so these values are lists of raw indexes.
    ALL = [
        B01, B02, B03, B04, B05, B06, B07,
        B08, B08A, B09, B10, B11, B12,
    ]
    RGB = [B04, B03, B02]
    NONE = []
class Seasons(Enum):
    """Seasons of the dataset; each value is the season's folder name."""

    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"

    # List of the four folder names (plain strings, not enum members).
    ALL = [SPRING, SUMMER, FALL, WINTER]
class Sensor(Enum):
    """Sensors / data products available for each patch."""

    s1 = "s1"
    s2 = "s2"
    # NOTE(review): this value is "s2cloudy", while the scene-listing code in
    # this file refers to cloudy folders as "s2_cloudy" -- confirm before
    # using this value to build paths.
    s2cloudy = "s2cloudy"
# Note: The order in which you request the bands is the same order they will be returned in.
class SEN12MSCRDataset:
    """Helper for locating and loading patches of the SEN12MS-CR dataset.

    ``base_dir`` is expected to contain one folder per season (named after the
    ``Seasons`` values). Each season folder holds per-scene folders such as
    ``s1_<scene>`` and ``s2_<scene>`` whose patches are GeoTIFFs named
    ``<season>_<sensor>_<scene>_p<patch>.tif``.
    """

    # On-disk prefix of the cloudy Sentinel-2 scene folders and files. It is
    # not the same string as ``Sensor.s2cloudy.value`` ("s2cloudy").
    # NOTE(review): assumes cloudy scenes live in "s2_cloudy_<scene>" folders
    # next to the s1/s2 ones -- confirm against the extracted dataset layout.
    _CLOUDY_PREFIX = "s2_cloudy"

    def __init__(self, base_dir):
        """Store the dataset root directory.

        Raises:
            Exception: if ``base_dir`` does not exist.
        """
        self.base_dir = base_dir
        if not os.path.exists(self.base_dir):
            raise Exception(
                "The specified base_dir for SEN12MS-CR dataset does not exist")

    @staticmethod
    def _as_list(ids):
        """Return ``ids`` as a new list, wrapping a single non-iterable id."""
        try:
            return list(ids)
        except TypeError:
            return [ids]

    def get_scene_ids(self, season):
        """Return the set of integer scene ids available for a season.

        Args:
            season: a ``Seasons`` member or its string value.

        Raises:
            NameError: if the season folder is missing from ``base_dir``.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season)
        if not os.path.exists(path):
            raise NameError("Could not find season {} in base directory {}".format(
                season, self.base_dir))

        scene_ids = set()
        for entry in glob(os.path.join(path, "*")):
            name = os.path.basename(entry)
            # Cloudy folders have an extra "_" that would break the split
            # below. Test the folder name only, so a base_dir that happens to
            # contain "s2_cloudy" does not filter out every scene.
            if self._CLOUDY_PREFIX in name:
                continue
            try:
                scene_ids.add(int(name.split("_")[1]))
            except (IndexError, ValueError):
                # Not a "<sensor>_<scene id>" entry (stray file/folder): skip.
                continue
        return scene_ids

    def get_patch_ids(self, season, scene_id):
        """Return the sorted integer patch ids of a scene.

        The ids are read from the ``.tif`` files of the scene's ``s1_<scene>``
        folder.

        Raises:
            NameError: if the scene folder does not exist for that season.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season, f"s1_{scene_id}")
        if not os.path.exists(path):
            raise NameError(
                "Could not find scene {} within season {}".format(scene_id, season))

        patch_ids = []
        # Only ".tif" files are loadable by get_patch; ignoring everything else
        # keeps sidecar files (e.g. "*.tif.aux.xml") from breaking the parsing.
        for patch_file in glob(os.path.join(path, "*.tif")):
            stem = os.path.splitext(os.path.basename(patch_file))[0]
            # "<season>_s1_<scene>_p<patch>" -> "p<patch>" -> "<patch>"
            patch_ids.append(int(stem.rsplit("_", 1)[1].split("p")[1]))
        # glob order is arbitrary; sort so results are reproducible.
        return sorted(patch_ids)

    def get_season_ids(self, season):
        """Return a dict mapping each scene id of a season to its patch ids."""
        season = Seasons(season).value
        return {sid: self.get_patch_ids(season, sid)
                for sid in sorted(self.get_scene_ids(season))}

    def get_patch(self, season, scene_id, patch_id, bands, sensor=None):
        """Load the requested bands of a single patch from a single sensor.

        Bands are returned in the order in which they are requested.

        Args:
            season: a ``Seasons`` member or its string value.
            scene_id: scene id within the season.
            patch_id: patch id within the scene.
            bands: one ``S1Bands``/``S2Bands`` member, or a non-empty
                list/tuple of members of the same enum.
            sensor: optional ``Sensor`` member or value. When omitted it is
                inferred from ``bands`` (``S1Bands`` -> s1, ``S2Bands`` -> s2).
                Pass ``Sensor.s2cloudy`` together with ``S2Bands`` to read the
                cloudy Sentinel-2 patch.

        Returns:
            ``(data, bounds)``: the raster data with a leading band axis and
            the bounds reported by rasterio for the patch.

        Raises:
            Exception: if ``bands`` is empty, is not made of band enum
                members, or does not fit the requested ``sensor``.
        """
        season = Seasons(season).value

        is_sequence = isinstance(bands, (list, tuple))
        if is_sequence and not bands:
            raise Exception("Invalid bands specified")
        first_band = bands[0] if is_sequence else bands

        if isinstance(first_band, S1Bands):
            inferred = Sensor.s1
        elif isinstance(first_band, S2Bands):
            inferred = Sensor.s2
        else:
            raise Exception("Invalid bands specified")

        sensor = inferred if sensor is None else Sensor(sensor)
        # S1 band enums only address S1 patches; S2 band enums address both
        # the clear and the cloudy Sentinel-2 patches.
        if (sensor is Sensor.s1) != (inferred is Sensor.s1):
            raise Exception("Bands do not match the requested sensor")

        if is_sequence:
            indexes = [band.value for band in bands]
        else:
            indexes = bands.value

        prefix = self._CLOUDY_PREFIX if sensor is Sensor.s2cloudy else sensor.value
        scene = "{}_{}".format(prefix, scene_id)
        filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
        patch_path = os.path.join(self.base_dir, season, scene, filename)

        with rasterio.open(patch_path) as patch:
            data = patch.read(indexes)
            bounds = patch.bounds

        # A single (non-list) band index yields a 2-D array; add the band axis
        # so callers always get 3-D data.
        if data.ndim == 2:
            data = np.expand_dims(data, axis=0)
        return data, bounds

    def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL,
                                 s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Load the S1, S2 and cloudy S2 data of one patch.

        Returns:
            ``(s1, s2, s2cloudy, bounds)`` where ``bounds`` are those of the
            S1 patch.
        """
        s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
        s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
        # The sensor must be given explicitly: S2Bands alone would resolve to
        # the clear S2 patch and return the same image twice.
        s2cloudy, _ = self.get_patch(season, scene_id, patch_id, s2cloudy_bands,
                                     sensor=Sensor.s2cloudy)
        return s1, s2, s2cloudy, bounds

    def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL,
                     s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Load triplets for several patches and stack them along a new axis.

        Args:
            season: a ``Seasons`` member or its string value.
            scene_ids: a scene id, an iterable of them, or ``None`` for every
                scene of the season.
            patch_ids: a patch id, an iterable of them, or ``None`` for every
                patch of each scene.

        Returns:
            ``(s1, s2, s2cloudy, bounds)``: three arrays of shape
            ``(D, B, ...)`` with ``D`` the number of loaded patches and ``B``
            the number of requested bands, followed by the raster's two
            spatial axes, plus the list of per-patch bounds.

        Raises:
            Exception: if both ``scene_ids`` and ``patch_ids`` are lists. Not
                all patch ids exist in all scenes, and not all scenes exist in
                all seasons.
            ValueError: if the selection contains no patch at all.
        """
        season = Seasons(season)

        if isinstance(scene_ids, list) and isinstance(patch_ids, list):
            raise Exception("Only scene_ids or patch_ids can be a list, not both.")

        if scene_ids is None:
            scene_list = sorted(self.get_scene_ids(season))
        else:
            scene_list = self._as_list(scene_ids)
        requested_patches = None if patch_ids is None else self._as_list(patch_ids)

        bounds = []
        s1_data = []
        s2_data = []
        s2cloudy_data = []
        for sid in scene_list:
            if requested_patches is None:
                patch_list = self.get_patch_ids(season, sid)
            else:
                patch_list = requested_patches
            for pid in patch_list:
                s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
                    season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
                s1_data.append(s1)
                s2_data.append(s2)
                s2cloudy_data.append(s2cloudy)
                bounds.append(bound)

        if not s1_data:
            # np.stack would fail with a far less helpful message.
            raise ValueError("No patches found for the requested scene_ids/patch_ids")

        return (np.stack(s1_data, axis=0), np.stack(s2_data, axis=0),
                np.stack(s2cloudy_data, axis=0), bounds)
if __name__ == "__main__":
    import time

    # Example usage. The dataset is looked up in the current directory.
    sen12mscr = SEN12MSCRDataset(".")

    # Scene id -> patch ids for every spring scene.
    spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
    cnt_patches = sum(len(pids) for pids in spring_ids.values())
    print("Spring: {} scenes with a total of {} patches".format(
        len(spring_ids), cnt_patches))

    # 1) RGB bands of the first S2 patch in scene 8.
    start = time.time()
    SCENE_ID = 8
    first_patch_id = spring_ids[SCENE_ID][0]
    s2_rgb_patch, bounds = sen12mscr.get_patch(
        Seasons.SPRING, SCENE_ID, first_patch_id, bands=S2Bands.RGB)
    print("Time Taken {}s".format(time.time() - start))
    print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))
    print("\n")

    # 2) One triplet from each of the first three spring scenes: all S1 bands
    #    plus the red/NIR (NDVI) bands of the S2 and cloudy S2 patches.
    ndvi_bands = [S2Bands.red, S2Bands.nir1]
    start = time.time()
    for scene_id, patch_ids in list(spring_ids.items())[:3]:
        s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(
            Seasons.SPRING, scene_id, patch_ids[0],
            s1_bands=S1Bands.ALL,
            s2_bands=ndvi_bands,
            s2cloudy_bands=ndvi_bands)
        print(
            f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")
    print("Time Taken {}s".format(time.time() - start))
    print("\n")

    # 3) All bands of all patches in scene 106.
    start = time.time()
    s1, s2, s2cloudy, _ = sen12mscr.get_triplets(
        Seasons.SPRING, 106,
        s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)
    print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
    print("Time Taken {}s".format(time.time() - start))