import os import numpy as np from osgeo import gdal from buzzard._a_stored_raster import ABackStoredRaster from buzzard._tools import conv, GDALErrorCatcher from buzzard import _tools class ABackGDALRaster(ABackStoredRaster): """Abstract class defining the common implementation of all GDAL rasters""" # get_data implementation ******************************************************************* ** def get_data(self, fp, channel_ids, dst_nodata, interpolation): samplefp = self.build_sampling_footprint(fp, interpolation) if samplefp is None: return np.full( np.r_[fp.shape, len(channel_ids)], dst_nodata, self.dtype ) with self.acquire_driver_object() as gdal_ds: array = self.sample_bands_driver(samplefp, channel_ids, gdal_ds) array = self.remap( samplefp, fp, array=array, mask=None, src_nodata=self.nodata, dst_nodata=dst_nodata, mask_mode='erode', interpolation=interpolation, ) array = array.astype(self.dtype, copy=False) return array def sample_bands_driver(self, fp, channel_ids, gdal_ds): rtlx, rtly = self.fp.spatial_to_raster(fp.tl) assert rtlx >= 0 and rtlx < self.fp.rsizex, '{} >= 0 and {} < {}'.format(rtlx, rtlx, self.fp.rsizex) assert rtly >= 0 and rtly < self.fp.rsizey, '{} >= 0 and {} < {}'.format(rtly, rtly, self.fp.rsizey) dstarray = np.empty(np.r_[fp.shape, len(channel_ids)], self.dtype) for i, channel_id in enumerate(channel_ids): gdal_band = gdal_ds.GetRasterBand(channel_id + 1) success, payload = GDALErrorCatcher(gdal_band.ReadAsArray, none_is_error=True)( int(rtlx), int(rtly), int(fp.rsizex), int(fp.rsizey), buf_obj=dstarray[..., i], ) if not success: # pragma: no cover raise ValueError('Could not read array (gdal error: `{}`)'.format( payload[1] )) return dstarray # set_data implementation ******************************************************************* ** def set_data(self, array, fp, channel_ids, interpolation, mask): if not fp.share_area(self.fp): return if not fp.same_grid(self.fp) and mask is None: mask = np.ones(fp.shape, bool) dstfp = self.fp.intersection(fp) # Remap **************************************************************** ret = self.remap( fp, dstfp, array=array, mask=mask, src_nodata=self.nodata, dst_nodata=self.nodata or 0, mask_mode='erode', interpolation=interpolation, ) if mask is not None: array, mask = ret else: array = ret del ret array = array.astype(self.dtype, copy=False) fp = dstfp del dstfp # Write **************************************************************** # TODO: Close all but 1 driver? Or let user do this with self.acquire_driver_object() as gdal_ds: for i, channel_id in enumerate(channel_ids): leftx, topy = self.fp.spatial_to_raster(fp.tl) gdalband = gdal_ds.GetRasterBand(channel_id + 1) for sl in _tools.slices_of_matrix(mask): a = array[:, :, i][sl] assert a.ndim == 2 x = int(sl[1].start + leftx) y = int(sl[0].start + topy) assert x >= 0 assert y >= 0 assert x + a.shape[1] <= self.fp.rsizex assert y + a.shape[0] <= self.fp.rsizey gdalband.WriteArray(a, x, y) # fill implementation *********************************************************************** ** def fill(self, value, channel_ids): with self.acquire_driver_object() as gdal_ds: for gdalband in [gdal_ds.GetRasterBand(channel_id + 1) for channel_id in channel_ids]: gdalband.Fill(value) # Misc ************************************************************************************** ** def acquire_driver_object(self): # pragma: no cover raise NotImplementedError('ABackGDALRaster.acquire_driver_object is virtual pure') @classmethod def create_file(cls, path, fp, dtype, channel_count, channels_schema, driver, options, wkt, ow): """Create a raster dataset""" # Step 0 - Find driver ********************************************** ** success, payload = GDALErrorCatcher(gdal.GetDriverByName, none_is_error=True)(driver) if not success: raise ValueError('Could not find a driver named `{}` (gdal error: `{}`)'.format( driver, payload[1] )) dr = payload # Step 1 - Overwrite ************************************************ ** if dr.ShortName != 'MEM' and os.path.exists(path): if ow: success, payload = GDALErrorCatcher(dr.Delete, nonzero_int_is_error=True)(path) if not success: msg = 'Could not delete `{}` using driver `{}` (gdal error: `{}`)'.format( path, dr.ShortName, payload[1] ) raise RuntimeError(msg) else: msg = "Can't create `{}` with `ow=False` (overwrite) because file exist".format( path, ) raise RuntimeError(msg) # Step 2 - Create gdal_ds ******************************************* ** options = [str(arg) for arg in options] success, payload = GDALErrorCatcher(dr.Create)( path, fp.rsizex, fp.rsizey, channel_count, conv.gdt_of_any_equiv(dtype), options ) if not success: # pragma: no cover raise RuntimeError('Could not create `{}` using driver `{}` (gdal error: `{}`)'.format( path, dr.ShortName, payload[1] )) gdal_ds = payload # Step 3 - Set spatial reference ************************************ ** if wkt is not None: gdal_ds.SetProjection(wkt) gdal_ds.SetGeoTransform(fp.gt) # Step 4 - Set channels schema ************************************** ** channels_schema = _tools.sanitize_channels_schema(channels_schema, channel_count) cls._apply_channels_schema(gdal_ds, channels_schema) gdal_ds.FlushCache() return gdal_ds @staticmethod def _apply_channels_schema(gdal_ds, channels_schema): """Used on file creation""" if 'nodata' in channels_schema: for i, val in enumerate(channels_schema['nodata'], 1): if val is not None: gdal_ds.GetRasterBand(i).SetNoDataValue(val) if 'interpretation' in channels_schema: for i, val in enumerate(channels_schema['interpretation'], 1): val = conv.gci_of_str(val) gdal_ds.GetRasterBand(i).SetColorInterpretation(val) if 'offset' in channels_schema: for i, val in enumerate(channels_schema['offset'], 1): gdal_ds.GetRasterBand(i).SetOffset(val) if 'scale' in channels_schema: for i, val in enumerate(channels_schema['scale'], 1): gdal_ds.GetRasterBand(i).SetScale(val) if 'mask' in channels_schema: shared_bit = conv.gmf_of_str('per_dataset') for i, val in enumerate(channels_schema['mask'], 1): val = conv.gmf_of_str(val) if val & shared_bit: gdal_ds.CreateMaskBand(val) break for i, val in enumerate(channels_schema['mask'], 1): val = conv.gmf_of_str(val) if not val & shared_bit: gdal_ds.GetRasterBand(i).CreateMaskBand(val) @staticmethod def _channels_schema_of_gdal_ds(gdal_ds): """Used on file opening""" bands = [gdal_ds.GetRasterBand(i + 1) for i in range(gdal_ds.RasterCount)] return { 'nodata': [band.GetNoDataValue() for band in bands], 'interpretation': [conv.str_of_gci(band.GetColorInterpretation()) for band in bands], 'offset': [band.GetOffset() if band.GetOffset() is not None else 0. for band in bands], 'scale': [band.GetScale() if band.GetScale() is not None else 1. for band in bands], 'mask': [conv.str_of_gmf(band.GetMaskFlags()) for band in bands], }