271 lines
9.1 KiB
Python
Executable File
271 lines
9.1 KiB
Python
Executable File
"""
|
|
Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1,
|
|
Sentinel 2 and cloudy Sentinel 2 data.
|
|
|
|
The SEN12MS-CR class is meant to provide a set of helper routines for loading individual
|
|
image patches as well as triplets of patches from the dataset. These routines can easily
|
|
be wrapped or extended for use with many deep learning frameworks or as standalone helper
|
|
methods. For an example use case please see the "main" routine at the end of this file.
|
|
|
|
NOTE: Some folder/file existence and validity checks are implemented but they are
|
|
by no means complete.
|
|
|
|
Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de),
|
|
based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied.
|
|
"""
|
|
|
|
import os
|
|
import rasterio
|
|
|
|
import numpy as np
|
|
|
|
from enum import Enum
|
|
from glob import glob
|
|
|
|
|
|
class S1Bands(Enum):
    """Band selectors for Sentinel-1 patches.

    ``VV`` and ``VH`` carry an integer band index; ``ALL`` and ``NONE`` carry
    a list of such indices, so ``S1Bands.ALL.value`` is ``[1, 2]``.
    """

    # Single bands.
    VV = 1
    VH = 2

    # Band groups. Inside the class body VV and VH are still plain integers,
    # which is why the resulting member value is a list of ints.
    ALL = [VV, VH]
    NONE = []
|
|
|
|
|
|
class S2Bands(Enum):
    """Band selectors for Sentinel-2 patches.

    Every single band is reachable under two names: its band designation
    (``B01`` ... ``B12``), which is the canonical member, and a descriptive
    alias that resolves to the very same member (``S2Bands.red is
    S2Bands.B04``). Single-band values are integers 1..13; ``ALL``, ``RGB``
    and ``NONE`` carry lists of those integers.
    """

    # Canonical name first, alias second: the first name bound to a value
    # becomes the member, any later name with the same value is an alias.
    B01 = 1
    aerosol = 1
    B02 = 2
    blue = 2
    B03 = 3
    green = 3
    B04 = 4
    red = 4
    B05 = 5
    re1 = 5
    B06 = 6
    re2 = 6
    B07 = 7
    re3 = 7
    B08 = 8
    nir1 = 8
    B08A = 9
    nir2 = 9
    B09 = 10
    vapor = 10
    B10 = 11
    cirrus = 11
    B11 = 12
    swir1 = 12
    B12 = 13
    swir2 = 13

    # Band groups. The names used here are still plain integers inside the
    # class body, so the member values are lists of ints.
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    RGB = [B04, B03, B02]
    NONE = []
|
|
|
|
|
|
class Seasons(Enum):
    """Seasons of the dataset.

    The value of each single season is a string of the form
    ``ROIs<number>_<season>``; ``ALL`` carries the list of the four strings.
    """

    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"

    # Inside the class body the four names above are still plain strings,
    # so ALL.value is a list of str.
    ALL = [SPRING, SUMMER, FALL, WINTER]
|
|
|
|
|
|
class Sensor(Enum):
    """Identifiers of the three data sources of the dataset.

    Each member's value is a short string tag equal to the member's name.
    """

    s1 = "s1"              # Sentinel-1
    s2 = "s2"              # Sentinel-2
    s2cloudy = "s2cloudy"  # cloudy Sentinel-2
|
|
|
|
# Note: The order in which you request the bands is the same order they will be returned in.
|
|
|
|
|
|
class SEN12MSCRDataset:
    """Helper for locating and reading patches of the SEN12MS-CR dataset.

    Patch files are addressed as
    ``<base_dir>/<season>/<prefix>_<scene_id>/<season>_<prefix>_<scene_id>_p<patch_id>.tif``
    where ``<season>`` is the value of a ``Seasons`` member and ``<prefix>``
    identifies the sensor.
    """

    def __init__(self, base_dir):
        """Store the dataset root directory.

        Args:
            base_dir: Directory containing one sub-directory per season.

        Raises:
            Exception: If ``base_dir`` does not exist.
        """
        self.base_dir = base_dir

        if not os.path.exists(self.base_dir):
            raise Exception(
                "The specified base_dir for SEN12MS-CR dataset does not exist")

    @staticmethod
    def _as_id_list(ids):
        """Return ``ids`` as a list, wrapping a single id into a one-element list."""
        # A string is iterable, but iterating it would split an id such as
        # "106" into its characters, so it is treated as one id.
        if isinstance(ids, (str, bytes)):
            return [ids]
        try:
            return list(ids)
        except TypeError:
            return [ids]

    def get_scene_ids(self, season):
        """Return the set of scene ids found for a specific season.

        Args:
            season: A ``Seasons`` member or its string value.

        Returns:
            A ``set`` of integer scene ids, parsed from the sub-directory
            names of the season directory.

        Raises:
            NameError: If the season directory does not exist.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season)

        if not os.path.exists(path):
            raise NameError("Could not find season {} in base directory {}".format(
                season, self.base_dir))

        scene_ids = set()
        for entry in glob(os.path.join(path, "*")):
            name = os.path.basename(entry)
            # "s2_cloudy" folders are skipped because their extra underscore
            # breaks the split below. The test is done on the folder name,
            # not on the full path, so that a base directory whose own path
            # contains "s2_cloudy" does not hide every scene.
            if "s2_cloudy" in name:
                continue
            # Scenes are folders; stray files next to them are not scenes.
            if not os.path.isdir(entry):
                continue
            scene_ids.add(int(name.split("_")[1]))

        return scene_ids

    def get_patch_ids(self, season, scene_id):
        """Return the patch ids of a specific scene within a specific season.

        The ids are read from the file names in the scene's ``s1_<scene_id>``
        directory.

        Args:
            season: A ``Seasons`` member or its string value.
            scene_id: Id of the scene.

        Returns:
            A sorted list of integer patch ids.

        Raises:
            NameError: If the scene directory does not exist.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season, f"s1_{scene_id}")

        if not os.path.exists(path):
            raise NameError(
                "Could not find scene {} within season {}".format(scene_id, season))

        patch_ids = []
        # Only "*.tif" files are considered: get_patch can only open files
        # with that extension, and other files would break the id parsing.
        for tif_path in glob(os.path.join(path, "*.tif")):
            stem = os.path.splitext(os.path.basename(tif_path))[0]
            # The stem ends in "_p<patch_id>".
            patch_ids.append(int(stem.rsplit("_", 1)[1].split("p")[1]))

        # glob order depends on the file system; sort for reproducibility.
        return sorted(patch_ids)

    def get_season_ids(self, season):
        """Return a dict mapping each scene id of a season to its patch ids.

        Args:
            season: A ``Seasons`` member or its string value.

        Returns:
            ``{scene_id: [patch_id, ...]}``.
        """
        season = Seasons(season).value
        return {sid: self.get_patch_ids(season, sid)
                for sid in self.get_scene_ids(season)}

    def get_patch(self, season, scene_id, patch_id, bands, sensor=None):
        """Return raster data and image bounds for the given bands of one patch.

        This method only loads a single patch from a single sensor.

        Args:
            season: A ``Seasons`` member or its string value.
            scene_id: Id of the scene.
            patch_id: Id of the patch within the scene.
            bands: One ``S1Bands``/``S2Bands`` member, or a non-empty
                list/tuple of members of one of those enums. Bands are
                returned in the order requested.
            sensor: Optional ``Sensor`` member (or its string value). When
                omitted, ``S1Bands`` select ``Sensor.s1`` and ``S2Bands``
                select ``Sensor.s2``. Pass ``Sensor.s2cloudy`` together with
                ``S2Bands`` to read the cloudy Sentinel-2 patch.

        Returns:
            ``(data, bounds)``: ``data`` is the array returned by rasterio's
            ``read`` with a leading band axis added when a single band was
            read as a 2-D array; ``bounds`` is the ``bounds`` attribute of the
            opened file.

        Raises:
            Exception: If ``bands`` is empty, mixes band enums, is not a band
                enum at all, or does not fit the requested ``sensor``.
        """
        season = Seasons(season).value

        is_sequence = isinstance(bands, (list, tuple))
        members = list(bands) if is_sequence else [bands]

        # Every requested band has to come from the same band enum.
        if members and all(isinstance(b, S1Bands) for b in members):
            inferred = Sensor.s1
        elif members and all(isinstance(b, S2Bands) for b in members):
            inferred = Sensor.s2
        else:
            raise Exception("Invalid bands specified")

        if sensor is None:
            sensor = inferred
        else:
            sensor = Sensor(sensor)
            # Cloudy Sentinel-2 patches are addressed with S2Bands as well.
            required = Sensor.s2 if sensor is Sensor.s2cloudy else sensor
            if required is not inferred:
                raise Exception(
                    "The specified bands do not belong to the requested sensor")

        band_indices = [b.value for b in members] if is_sequence else bands.value

        # The scene listing skips folders named "s2_cloudy...", so the cloudy
        # folder prefix differs from Sensor.s2cloudy.value ("s2cloudy").
        # NOTE(review): assumes cloudy file names follow the same
        # "<season>_<prefix>_<scene>_p<patch>.tif" pattern - confirm against
        # the dataset on disk.
        prefix = "s2_cloudy" if sensor is Sensor.s2cloudy else sensor.value

        scene = "{}_{}".format(prefix, scene_id)
        filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
        patch_path = os.path.join(self.base_dir, season, scene, filename)

        with rasterio.open(patch_path) as patch:
            data = patch.read(band_indices)
            bounds = patch.bounds

        # A single band index yields a 2-D array; add the band axis back so
        # callers always get a 3-D array.
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)

        return data, bounds

    def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Return the S1, S2 and cloudy S2 data of one patch plus its bounds.

        Args:
            season: A ``Seasons`` member or its string value.
            scene_id: Id of the scene.
            patch_id: Id of the patch within the scene.
            s1_bands: ``S1Bands`` selection for the Sentinel-1 patch.
            s2_bands: ``S2Bands`` selection for the Sentinel-2 patch.
            s2cloudy_bands: ``S2Bands`` selection for the cloudy Sentinel-2
                patch.

        Returns:
            ``(s1, s2, s2cloudy, bounds)`` where ``bounds`` are those of the
            Sentinel-1 patch.
        """
        s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
        s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
        # The sensor must be named explicitly: the bands are S2Bands, which
        # on their own would select the cloud-free Sentinel-2 patch again.
        s2cloudy, _ = self.get_patch(
            season, scene_id, patch_id, s2cloudy_bands, sensor=Sensor.s2cloudy)

        return s1, s2, s2cloudy, bounds

    def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Return stacked S1, S2 and cloudy S2 data for several patches.

        Args:
            season: A ``Seasons`` member or its string value.
            scene_ids: A single scene id, an iterable of scene ids, or
                ``None`` for every scene of the season.
            patch_ids: A single patch id, an iterable of patch ids, or
                ``None`` for every patch of each selected scene.
            s1_bands: ``S1Bands`` selection for the Sentinel-1 patches.
            s2_bands: ``S2Bands`` selection for the Sentinel-2 patches.
            s2cloudy_bands: ``S2Bands`` selection for the cloudy patches.

        Returns:
            ``(s1, s2, s2cloudy, bounds)``. The three arrays have a leading
            axis of length D, the number of loaded patches, followed by the
            band axis and the two image axes of each patch; ``bounds`` is a
            list with one entry per patch, in the same order.

        Raises:
            Exception: If both ``scene_ids`` and ``patch_ids`` are lists.
            ValueError: If the selection does not yield any patch.
        """
        season = Seasons(season)

        # This is due to the fact that not all patch ids are available in all
        # scenes, and not all scenes exist in all seasons.
        if isinstance(scene_ids, list) and isinstance(patch_ids, list):
            raise Exception("Only scene_ids or patch_ids can be a list, not both.")

        if scene_ids is None:
            # Sorted so that the order of the stacked patches is reproducible.
            scene_list = sorted(self.get_scene_ids(season))
        else:
            scene_list = self._as_id_list(scene_ids)

        requested_patches = None if patch_ids is None else self._as_id_list(patch_ids)

        s1_data = []
        s2_data = []
        s2cloudy_data = []
        bounds = []

        for sid in scene_list:
            if requested_patches is None:
                patch_list = self.get_patch_ids(season, sid)
            else:
                patch_list = requested_patches

            for pid in patch_list:
                s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
                    season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
                s1_data.append(s1)
                s2_data.append(s2)
                s2cloudy_data.append(s2cloudy)
                bounds.append(bound)

        # np.stack rejects an empty list with a ValueError of its own; raise
        # the same exception type with a message that names the actual cause.
        if not s1_data:
            raise ValueError(
                "No patches found for the requested scene_ids/patch_ids")

        return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(s2cloudy_data, axis=0), bounds
|
|
|
|
|
|
if __name__ == "__main__":
    import time

    # Open the dataset rooted at the current working directory.
    sen12mscr = SEN12MSCRDataset(".")

    # --- Overview: scenes and patches available for spring -----------------
    spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
    cnt_patches = sum(len(pids) for pids in spring_ids.values())
    print("Spring: {} scenes with a total of {} patches".format(
        len(spring_ids), cnt_patches))

    # --- Example 1: RGB bands of the first listed S2 patch of scene 8 ------
    start = time.time()
    SCENE_ID = 8
    first_patch_id = spring_ids[SCENE_ID][0]
    s2_rgb_patch, bounds = sen12mscr.get_patch(
        Seasons.SPRING, SCENE_ID, first_patch_id, bands=S2Bands.RGB)
    print("Time Taken {}s".format(time.time() - start))

    print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))

    print("\n")

    # --- Example 2: one triplet from each of the first three spring scenes -
    # All S1 bands; red and nir1 (the bands needed for NDVI) for both the
    # S2 and the cloudy S2 patch.
    ndvi_bands = [S2Bands.red, S2Bands.nir1]
    start = time.time()
    for scene_id, patch_ids in list(spring_ids.items())[:3]:
        s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(
            Seasons.SPRING, scene_id, patch_ids[0],
            s1_bands=S1Bands.ALL,
            s2_bands=ndvi_bands,
            s2cloudy_bands=ndvi_bands)
        print(
            f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")

    print("Time Taken {}s".format(time.time() - start))
    print("\n")

    # --- Example 3: all bands of all patches of scene 106 ------------------
    start = time.time()
    s1, s2, s2cloudy, _ = sen12mscr.get_triplets(
        Seasons.SPRING, 106,
        s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)

    print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
    print("Time Taken {}s".format(time.time() - start))
|