import logging
import re
import uuid
from multiprocessing import Manager, Process, cpu_count, current_process
from queue import Empty

import boto3
import click
import datacube
from botocore import UNSIGNED
from botocore.config import Config
from datacube.index.hl import Doc2Dataset
from datacube.utils import changes
from osgeo import osr
# Need to check if we're on a new GDAL for coordinate order
import osgeo.gdal
from packaging import version
from ruamel.yaml import YAML

GUARDIAN = "GUARDIAN_QUEUE_EMPTY"
AWS_PDS_TXT_SUFFIX = "MTL.txt"
LON_LAT_ORDER = version.parse(osgeo.gdal.__version__) < version.parse("3.0.0")

MTL_PAIRS_RE = re.compile(r'(\w+)\s=\s(.*)')

bands_ls8 = [('1', 'coastal_aerosol'),
             ('2', 'blue'),
             ('3', 'green'),
             ('4', 'red'),
             ('5', 'nir'),
             ('6', 'swir1'),
             ('7', 'swir2'),
             ('8', 'panchromatic'),
             ('9', 'cirrus'),
             ('10', 'lwir1'),
             ('11', 'lwir2'),
             ('QUALITY', 'quality')]

bands_ls7 = [('1', 'blue'),
             ('2', 'green'),
             ('3', 'red'),
             ('4', 'nir'),
             ('5', 'swir1'),
             ('7', 'swir2'),
             ('QUALITY', 'quality')]


def _parse_value(s):
    s = s.strip('"')
    for parser in [int, float]:
        try:
            return parser(s)
        except ValueError:
            pass
    return s


def _parse_group(lines):
    tree = {}
    for line in lines:
        match = MTL_PAIRS_RE.findall(line)
        if match:
            key, value = match[0]
            if key == 'GROUP':
                tree[value] = _parse_group(lines)
            elif key == 'END_GROUP':
                break
            else:
                tree[key] = _parse_value(value)
    return tree


def get_geo_ref_points(info):
    return {
        'ul': {'x': info['CORNER_UL_PROJECTION_X_PRODUCT'],
               'y': info['CORNER_UL_PROJECTION_Y_PRODUCT']},
        'ur': {'x': info['CORNER_UR_PROJECTION_X_PRODUCT'],
               'y': info['CORNER_UR_PROJECTION_Y_PRODUCT']},
        'll': {'x': info['CORNER_LL_PROJECTION_X_PRODUCT'],
               'y': info['CORNER_LL_PROJECTION_Y_PRODUCT']},
        'lr': {'x': info['CORNER_LR_PROJECTION_X_PRODUCT'],
               'y': info['CORNER_LR_PROJECTION_Y_PRODUCT']},
    }


def get_coords(geo_ref_points, spatial_ref):
    t = osr.CoordinateTransformation(spatial_ref, spatial_ref.CloneGeogCS())

    def transform(p):
        # GDAL 3 reverses the coordinate order, because... standards
        if LON_LAT_ORDER:
            # GDAL 2.x order
            lon, lat, z = t.TransformPoint(p['x'], p['y'])
        else:
            # GDAL 3.x order
            lat, lon, z = t.TransformPoint(p['x'], p['y'])
        return {'lon': lon, 'lat': lat}

    return {key: transform(p) for key, p in geo_ref_points.items()}


def satellite_ref(sat):
    """
    Load the band names for referencing either LANDSAT_8 or LANDSAT_7/5 bands.
    """
    if sat == 'LANDSAT_8':
        sat_img = bands_ls8
    elif sat == 'LANDSAT_7' or sat == 'LANDSAT_5':
        sat_img = bands_ls7
    else:
        raise ValueError('Satellite data not supported')
    return sat_img


def absolutify_paths(doc, bucket_name, obj_key):
    objt_key = format_obj_key(obj_key)
    for band in doc['image']['bands'].values():
        band['path'] = get_s3_url(bucket_name, objt_key + '/' + band['path'])
    return doc


def make_metadata_doc(mtl_data, bucket_name, object_key):
    mtl_product_info = mtl_data['PRODUCT_METADATA']
    mtl_metadata_info = mtl_data['METADATA_FILE_INFO']
    satellite = mtl_product_info['SPACECRAFT_ID']
    instrument = mtl_product_info['SENSOR_ID']
    acquisition_date = mtl_product_info['DATE_ACQUIRED']
    scene_center_time = mtl_product_info['SCENE_CENTER_TIME']
    level = mtl_product_info['DATA_TYPE']
    product_type = 'L1TP'
    sensing_time = acquisition_date + ' ' + scene_center_time
    cs_code = 32600 + mtl_data['PROJECTION_PARAMETERS']['UTM_ZONE']
    label = mtl_metadata_info['LANDSAT_SCENE_ID']
    spatial_ref = osr.SpatialReference()
    spatial_ref.ImportFromEPSG(cs_code)
    geo_ref_points = get_geo_ref_points(mtl_product_info)
    coordinates = get_coords(geo_ref_points, spatial_ref)
    bands = satellite_ref(satellite)
    doc = {
        'id': str(uuid.uuid5(uuid.NAMESPACE_URL, get_s3_url(bucket_name, object_key))),
        'processing_level': level,
        'product_type': product_type,
        'creation_dt': str(acquisition_date),
        'label': label,
        'platform': {'code': satellite},
        'instrument': {'name': instrument},
        'extent': {
            'from_dt': sensing_time,
            'to_dt': sensing_time,
            'center_dt': sensing_time,
            'coord': coordinates,
        },
        'format': {'name': 'GeoTiff'},
        'grid_spatial': {
            'projection': {
                'geo_ref_points': geo_ref_points,
                'spatial_reference': 'EPSG:%s' % cs_code,
            }
        },
        'image': {
            'bands': {
                band[1]: {
                    'path': mtl_product_info['FILE_NAME_BAND_' + band[0]],
                    'layer': 1,
                } for band in bands
            }
        },
        'lineage': {'source_datasets': {}},
    }
    doc = absolutify_paths(doc, bucket_name, object_key)
    return doc


def format_obj_key(obj_key):
    obj_key = '/'.join(obj_key.split("/")[:-1])
    return obj_key


def get_s3_url(bucket_name, obj_key):
    return 's3://{bucket_name}/{obj_key}'.format(
        bucket_name=bucket_name, obj_key=obj_key)


def archive_document(doc, uri, index, sources_policy):
    def get_ids(dataset):
        ds = index.datasets.get(dataset.id, include_sources=True)
        for source in ds.sources.values():
            yield source.id
        yield dataset.id

    resolver = Doc2Dataset(index)
    dataset, _ = resolver(doc, uri)
    index.datasets.archive(get_ids(dataset))
    logging.info("Archiving %s and all of its sources", dataset.id)


def add_dataset(doc, uri, index, sources_policy):
    logging.info("Indexing %s", uri)
    resolver = Doc2Dataset(index)
    dataset, err = resolver(doc, uri)
    if err is not None:
        logging.error("%s", err)
    else:
        try:
            # Sources policy to be checked for Sentinel-2 dataset types
            index.datasets.add(dataset, sources_policy=sources_policy)
        except changes.DocumentMismatchError:
            index.datasets.update(dataset, {tuple(): changes.allow_any})
        except Exception as e:
            err = e
            logging.error("Unhandled exception %s", e)

    return dataset, err


def worker(config, bucket_name, prefix, suffix, start_date, end_date,
           func, unsafe, sources_policy, queue):
    dc = datacube.Datacube(config=config)
    index = dc.index
    s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
    safety = 'safe' if not unsafe else 'unsafe'

    while True:
        try:
            key = queue.get(timeout=60)
            if key == GUARDIAN:
                break
            logging.info("Processing %s %s", key, current_process())
            obj = s3.Object(bucket_name, key).get()
            raw = obj['Body'].read()
            if suffix == AWS_PDS_TXT_SUFFIX:
                # Attempt to process the MTL text document
                raw_string = raw.decode('utf8')
                txt_doc = _parse_group(iter(raw_string.split("\n")))['L1_METADATA_FILE']
                data = make_metadata_doc(txt_doc, bucket_name, key)
            else:
                yaml = YAML(typ=safety, pure=False)
                yaml.default_flow_style = False
                data = yaml.load(raw)
            uri = get_s3_url(bucket_name, key)
            cdt = data['creation_dt']
            # Use the fact that lexicographical ordering of ISO dates matches
            # chronological ordering; skip the filter when a bound is not given.
            if (start_date is None or cdt >= start_date) and \
                    (end_date is None or cdt < end_date):
                logging.info("calling %s", func)
                func(data, uri, index, sources_policy)
            queue.task_done()
        except Empty:
            break
        except EOFError:
            break


def iterate_datasets(bucket_name, config, prefix, suffix, start_date, end_date,
                     func, unsafe, sources_policy):
    manager = Manager()
    queue = manager.Queue()
    s3 = boto3.resource('s3', config=Config(signature_version=UNSIGNED))
    bucket = s3.Bucket(bucket_name)
    logging.info("Bucket: %s prefix: %s", bucket_name, str(prefix))
    worker_count = cpu_count() * 2

    processes = []
    for i in range(worker_count):
        proc = Process(target=worker, args=(config, bucket_name, prefix, suffix,
                                            start_date, end_date, func, unsafe,
                                            sources_policy, queue,))
        processes.append(proc)
        proc.start()

    for obj in bucket.objects.filter(Prefix=str(prefix)):
        if obj.key.endswith(suffix):
            queue.put(obj.key)

    for i in range(worker_count):
        queue.put(GUARDIAN)

    for proc in processes:
        proc.join()


@click.command(help="Enter Bucket name. Optional to enter configuration file to access a different database")
@click.argument('bucket_name')
@click.option('--config', '-c', help="Pass the configuration file to access the database",
              type=click.Path(exists=True))
@click.option('--prefix', '-p', help="Pass the prefix of the object to the bucket")
@click.option('--suffix', '-s', default=".yaml",
              help="Defines the suffix of the metadata_docs that will be used to load datasets. For the AWS PDS bucket use MTL.txt")
@click.option('--start_date', help="Pass the start acquisition date, in YYYY-MM-DD format")
@click.option('--end_date', help="Pass the end acquisition date, in YYYY-MM-DD format")
@click.option('--archive', is_flag=True,
              help="If set, datasets found in the specified bucket and prefix will be archived")
@click.option('--unsafe', is_flag=True,
              help="If set, YAML will be parsed unsafely. Only use on trusted datasets. Only valid if suffix is yaml")
@click.option('--sources_policy', default="verify", help="verify, ensure, skip")
def main(bucket_name, config, prefix, suffix, start_date, end_date, archive, unsafe, sources_policy):
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
    action = archive_document if archive else add_dataset
    iterate_datasets(bucket_name, config, prefix, suffix, start_date, end_date,
                     action, unsafe, sources_policy)


if __name__ == "__main__":
    main()
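
# ---------------------------------------------------------------------------
# Example invocation (illustrative sketch only; the script filename, bucket,
# prefix and dates below are placeholders, not values mandated by this code):
#
#   python index_s3_bucket.py my-landsat-bucket \
#       --prefix c1/L8/139/045 \
#       --suffix MTL.txt \
#       --start_date 2018-01-01 \
#       --end_date 2019-01-01
#
# If --config is omitted, datacube falls back to its default configuration.
# If --start_date/--end_date are omitted, every object matching the prefix and
# suffix is indexed regardless of acquisition date.
# ---------------------------------------------------------------------------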