init
This commit is contained in:
270
sen12ms_cr_dataLoader.py
Executable file
270
sen12ms_cr_dataLoader.py
Executable file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1,
|
||||
Sentinel 2 and cloudy Sentinel 2 data.
|
||||
|
||||
The SEN12MS-CR class is meant to provide a set of helper routines for loading individual
|
||||
image patches as well as triplets of patches from the dataset. These routines can easily
|
||||
be wrapped or extended for use with many deep learning frameworks or as standalone helper
|
||||
methods. For an example use case please see the "main" routine at the end of this file.
|
||||
|
||||
NOTE: Some folder/file existence and validity checks are implemented but it is
|
||||
by no means complete.
|
||||
|
||||
Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de),
|
||||
based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied.
|
||||
"""
|
||||
|
||||
import os
|
||||
import rasterio
|
||||
|
||||
import numpy as np
|
||||
|
||||
from enum import Enum
|
||||
from glob import glob
|
||||
|
||||
|
||||
class S1Bands(Enum):
|
||||
VV = 1
|
||||
VH = 2
|
||||
ALL = [VV, VH]
|
||||
NONE = []
|
||||
|
||||
|
||||
class S2Bands(Enum):
|
||||
B01 = aerosol = 1
|
||||
B02 = blue = 2
|
||||
B03 = green = 3
|
||||
B04 = red = 4
|
||||
B05 = re1 = 5
|
||||
B06 = re2 = 6
|
||||
B07 = re3 = 7
|
||||
B08 = nir1 = 8
|
||||
B08A = nir2 = 9
|
||||
B09 = vapor = 10
|
||||
B10 = cirrus = 11
|
||||
B11 = swir1 = 12
|
||||
B12 = swir2 = 13
|
||||
ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
|
||||
RGB = [B04, B03, B02]
|
||||
NONE = []
|
||||
|
||||
|
||||
class Seasons(Enum):
|
||||
SPRING = "ROIs1158_spring"
|
||||
SUMMER = "ROIs1868_summer"
|
||||
FALL = "ROIs1970_fall"
|
||||
WINTER = "ROIs2017_winter"
|
||||
ALL = [SPRING, SUMMER, FALL, WINTER]
|
||||
|
||||
|
||||
class Sensor(Enum):
|
||||
s1 = "s1"
|
||||
s2 = "s2"
|
||||
s2cloudy = "s2cloudy"
|
||||
|
||||
# Note: The order in which you request the bands is the same order they will be returned in.
|
||||
|
||||
|
||||
class SEN12MSCRDataset:
|
||||
def __init__(self, base_dir):
|
||||
self.base_dir = base_dir
|
||||
|
||||
if not os.path.exists(self.base_dir):
|
||||
raise Exception(
|
||||
"The specified base_dir for SEN12MS-CR dataset does not exist")
|
||||
|
||||
"""
|
||||
Returns a list of scene ids for a specific season.
|
||||
"""
|
||||
|
||||
def get_scene_ids(self, season):
|
||||
season = Seasons(season).value
|
||||
path = os.path.join(self.base_dir, season)
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise NameError("Could not find season {} in base directory {}".format(
|
||||
season, self.base_dir))
|
||||
|
||||
# add all dirs except "s2_cloudy" (which messes with subsequent string splits)
|
||||
scene_list = [os.path.basename(s)
|
||||
for s in glob(os.path.join(path, "*")) if "s2_cloudy" not in s]
|
||||
scene_list = [int(s.split("_")[1]) for s in scene_list]
|
||||
return set(scene_list)
|
||||
|
||||
"""
|
||||
Returns a list of patch ids for a specific scene within a specific season
|
||||
"""
|
||||
|
||||
def get_patch_ids(self, season, scene_id):
|
||||
season = Seasons(season).value
|
||||
path = os.path.join(self.base_dir, season, f"s1_{scene_id}")
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise NameError(
|
||||
"Could not find scene {} within season {}".format(scene_id, season))
|
||||
|
||||
patch_ids = [os.path.splitext(os.path.basename(p))[0]
|
||||
for p in glob(os.path.join(path, "*"))]
|
||||
patch_ids = [int(p.rsplit("_", 1)[1].split("p")[1]) for p in patch_ids]
|
||||
|
||||
return patch_ids
|
||||
|
||||
"""
|
||||
Return a dict of scene ids and their corresponding patch ids.
|
||||
key => scene_ids, value => list of patch_ids
|
||||
"""
|
||||
|
||||
def get_season_ids(self, season):
|
||||
season = Seasons(season).value
|
||||
ids = {}
|
||||
scene_ids = self.get_scene_ids(season)
|
||||
|
||||
for sid in scene_ids:
|
||||
ids[sid] = self.get_patch_ids(season, sid)
|
||||
|
||||
return ids
|
||||
|
||||
"""
|
||||
Returns raster data and image bounds for the defined bands of a specific patch
|
||||
This method only loads a sinlge patch from a single sensor as defined by the bands specified
|
||||
"""
|
||||
|
||||
def get_patch(self, season, scene_id, patch_id, bands):
|
||||
season = Seasons(season).value
|
||||
sensor = None
|
||||
|
||||
if isinstance(bands, (list, tuple)):
|
||||
b = bands[0]
|
||||
else:
|
||||
b = bands
|
||||
|
||||
if isinstance(b, S1Bands):
|
||||
sensor = Sensor.s1.value
|
||||
bandEnum = S1Bands
|
||||
elif isinstance(b, S2Bands):
|
||||
sensor = Sensor.s2.value
|
||||
bandEnum = S2Bands
|
||||
else:
|
||||
raise Exception("Invalid bands specified")
|
||||
|
||||
if isinstance(bands, (list, tuple)):
|
||||
bands = [b.value for b in bands]
|
||||
else:
|
||||
bands = bands.value
|
||||
|
||||
scene = "{}_{}".format(sensor, scene_id)
|
||||
filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
|
||||
patch_path = os.path.join(self.base_dir, season, scene, filename)
|
||||
|
||||
with rasterio.open(patch_path) as patch:
|
||||
data = patch.read(bands)
|
||||
bounds = patch.bounds
|
||||
|
||||
if len(data.shape) == 2:
|
||||
data = np.expand_dims(data, axis=0)
|
||||
|
||||
return data, bounds
|
||||
|
||||
"""
|
||||
Returns a triplet of patches. S1, S2 and cloudy S2 as well as the geo-bounds of the patch
|
||||
"""
|
||||
|
||||
def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
|
||||
s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
|
||||
s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
|
||||
s2cloudy, _ = self.get_patch(season, scene_id, patch_id, s2cloudy_bands)
|
||||
|
||||
return s1, s2, s2cloudy, bounds
|
||||
|
||||
"""
|
||||
Returns a triplet of numpy arrays with dimensions D, B, W, H where D is the number of patches specified
|
||||
using scene_ids and patch_ids and B is the number of bands for S1, S2 or cloudy S2
|
||||
"""
|
||||
|
||||
def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
|
||||
season = Seasons(season)
|
||||
scene_list = []
|
||||
patch_list = []
|
||||
bounds = []
|
||||
s1_data = []
|
||||
s2_data = []
|
||||
s2cloudy_data = []
|
||||
|
||||
# This is due to the fact that not all patch ids are available in all scenes
|
||||
# And not all scenes exist in all seasons
|
||||
if isinstance(scene_ids, list) and isinstance(patch_ids, list):
|
||||
raise Exception("Only scene_ids or patch_ids can be a list, not both.")
|
||||
|
||||
if scene_ids is None:
|
||||
scene_list = self.get_scene_ids(season)
|
||||
else:
|
||||
try:
|
||||
scene_list.extend(scene_ids)
|
||||
except TypeError:
|
||||
scene_list.append(scene_ids)
|
||||
|
||||
if patch_ids is not None:
|
||||
try:
|
||||
patch_list.extend(patch_ids)
|
||||
except TypeError:
|
||||
patch_list.append(patch_ids)
|
||||
|
||||
for sid in scene_list:
|
||||
if patch_ids is None:
|
||||
patch_list = self.get_patch_ids(season, sid)
|
||||
|
||||
for pid in patch_list:
|
||||
s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
|
||||
season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
|
||||
s1_data.append(s1)
|
||||
s2_data.append(s2)
|
||||
s2cloudy_data.append(s2cloudy)
|
||||
bounds.append(bound)
|
||||
|
||||
return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(s2cloudy_data, axis=0), bounds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
# Load the dataset specifying the base directory
|
||||
sen12mscr = SEN12MSCRDataset(".")
|
||||
|
||||
spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
|
||||
cnt_patches = sum([len(pids) for pids in spring_ids.values()])
|
||||
print("Spring: {} scenes with a total of {} patches".format(
|
||||
len(spring_ids), cnt_patches))
|
||||
|
||||
start = time.time()
|
||||
# Load the RGB bands of the first S2 patch in scene 8
|
||||
SCENE_ID = 8
|
||||
s2_rgb_patch, bounds = sen12mscr.get_patch(Seasons.SPRING, SCENE_ID,
|
||||
spring_ids[SCENE_ID][0], bands=S2Bands.RGB)
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
|
||||
print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))
|
||||
|
||||
print("\n")
|
||||
|
||||
# Load a triplet of patches from the first three scenes of Spring - all S1 bands, NDVI S2 bands, and NDVI S2 cloudy bands
|
||||
i = 0
|
||||
start = time.time()
|
||||
for scene_id, patch_ids in spring_ids.items():
|
||||
if i >= 3:
|
||||
break
|
||||
|
||||
s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(Seasons.SPRING, scene_id, patch_ids[0], s1_bands=S1Bands.ALL,
|
||||
s2_bands=[S2Bands.red, S2Bands.nir1], s2cloudy_bands=[S2Bands.red, S2Bands.nir1])
|
||||
print(
|
||||
f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")
|
||||
i += 1
|
||||
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
print("\n")
|
||||
|
||||
start = time.time()
|
||||
# Load all bands of all patches in a specified scene (scene 106)
|
||||
s1, s2, s2cloudy, _ = sen12mscr.get_triplets(Seasons.SPRING, 106, s1_bands=S1Bands.ALL,
|
||||
s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)
|
||||
|
||||
print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
Reference in New Issue
Block a user