import logging
import numpy

from collections import namedtuple
from osgeo import gdal, gdal_array

# A wrapper for a geospatial image.
#
# - Bands: A list of uint16 numpy arrays, each holding one band of data
# - Alpha: A boolean numpy array holding the alpha information
#          - False is a no-data pixel
#          - True is a valid pixel
# - Metadata: A dict containing georeferencing information
#             - geotransform, projection and rpc
# Record type for one geospatial image; fields, in positional order, are
# bands, alpha and metadata.
GImage = namedtuple(typename='GImage', field_names='bands, alpha, metadata')

def save(gimage, filename, nodata=None, compress=True):
    '''Write a GImage to ``filename`` as a newly created dataset.

    The dataset is sized from the shape of the first band and is given one
    raster band more than ``gimage.bands`` holds.

    - gimage: The GImage to write
    - filename: Path handed to create_ds
    - nodata: Passed through unchanged to _save_to_ds
    - compress: Passed through unchanged to create_ds
    '''
    total_bands = 1 + len(gimage.bands)
    # numpy shapes are (rows, columns), i.e. (y, x).
    rows, cols = gimage.bands[0].shape
    dataset = create_ds(filename, cols, rows, total_bands, compress)
    _save_to_ds(gimage, dataset, nodata)

def create_ds(file_name, xsize, ysize, band_count, compress=True):
    '''Create a new UInt16 dataset with the GTiff driver and return it.

    - file_name: Path of the file to create
    - xsize, ysize: Raster width and height in pixels
    - band_count: Total number of raster bands to allocate
    - compress: When true, a compression creation option is added

    Raises an Exception if the driver does not return a dataset.
    '''
    options = ['PHOTOMETRIC=RGB']
    if compress:
        # NOTE(review): LZW is used as the lossless codec here - confirm it
        # is the codec this project expects.
        options.append('COMPRESS=LZW')

    datatype = gdal.GDT_UInt16
    gdal_ds = gdal.GetDriverByName('GTIFF').Create(
        file_name, xsize, ysize, band_count, datatype,
        options=options)
    if gdal_ds is None:
        # Fail here, as load() does for gdal.Open(), instead of handing
        # None to the caller.
        raise Exception('Unable to create file "{}" : {}'.format(
            file_name, gdal.GetLastErrorMsg()))
    return gdal_ds

def _save_to_ds(gimage, gdal_ds, nodata=None):
    '''Write the bands, alpha and metadata of ``gimage`` into ``gdal_ds``.

    The dataset must already have one raster band more than
    ``gimage.bands`` holds and the same x/y size as the first band; this
    is checked with assertions before anything is written.

    - gimage: The GImage to write
    - gdal_ds: The dataset to write into
    - nodata: Passed through unchanged to save_band for every data band
    '''
    data_bands = gimage.bands
    assert gdal_ds.RasterCount == 1 + len(data_bands)
    first_shape = data_bands[0].shape
    assert gdal_ds.RasterXSize == first_shape[1]
    assert gdal_ds.RasterYSize == first_shape[0]

    # Data bands go first; gdal band numbers start at 1.
    for band_no, band in enumerate(data_bands, start=1):
        save_band(gdal_ds, band, band_no, nodata)

    save_alpha_band(gdal_ds, gimage.alpha)
    save_metadata(gdal_ds, gimage.metadata)

def save_band(gdal_ds, band_array, band_no, nodata=None):
    '''Write ``band_array`` into raster band ``band_no`` of ``gdal_ds``.

    - gdal_ds: The dataset to write into
    - band_array: The array written with gdal_array.BandWriteArray
    - band_no: gdal style band number, i.e. from 1 onwards not 0 indexed
    - nodata: If not None, set as the nodata value of the band
    '''
    gdal_band = gdal_ds.GetRasterBand(band_no)
    gdal_array.BandWriteArray(gdal_band, band_array)
    if nodata is not None:
        # "is not None" so that a nodata value of 0 is still recorded.
        gdal_band.SetNoDataValue(nodata)

def save_alpha_band(gdal_ds, alpha_array):
    '''Write ``alpha_array`` into the last raster band of ``gdal_ds``.

    The array is converted to uint16 and multiplied by 255 before it is
    written, so a True pixel is stored as 255 and a False pixel as 0.
    '''
    # The alpha band is the band numbered RasterCount, i.e. the last one.
    alpha_band = gdal_ds.GetRasterBand(gdal_ds.RasterCount)
    gdal_array.BandWriteArray(alpha_band,
                              alpha_array.astype(numpy.uint16) * 255)

def save_metadata(gdal_ds, metadata):
    '''Copy the georeferencing entries of ``metadata`` onto ``gdal_ds``.

    Only the keys 'projection', 'geotransform' and 'rpc' are used; a key
    that is absent is skipped.

    - gdal_ds: The dataset to update
    - metadata: A dict of georeferencing information
    '''
    # Save georeferencing information
    if 'projection' in metadata:
        gdal_ds.SetProjection(metadata['projection'])
    if 'geotransform' in metadata:
        gdal_ds.SetGeoTransform(metadata['geotransform'])
    if 'rpc' in metadata:
        # rpc values are stored in the 'RPC' metadata domain.
        gdal_ds.SetMetadata(metadata['rpc'], 'RPC')

def load(filename, nodata=None, last_band_alpha=False):
    '''Open ``filename`` with gdal and return its contents as a GImage.

    - filename: Path passed to gdal.Open
    - nodata: If not None, any pixel equal to this value in any band is
              marked False in the returned alpha
    - last_band_alpha: Passed through to read_alpha_and_band_count

    Raises an Exception if gdal.Open returns None.
    '''
    logging.info('GImage: Loading {} as GImage'.format(filename))
    gdal_ds = gdal.Open(filename)
    if gdal_ds is None:
        raise Exception('Unable to open file "{}" with gdal.Open()'.format(
            filename))

    alpha, band_count = read_alpha_and_band_count(gdal_ds, last_band_alpha)
    bands = _read_all_bands(gdal_ds, band_count)
    metadata = read_metadata(gdal_ds)

    # "is not None" rather than truthiness, so that a nodata value of 0 is
    # honoured (this matches the check in save_band).
    if nodata is not None:
        # Combine as booleans so the alpha stays a boolean array.
        alpha = alpha & (_nodata_to_mask(bands, nodata) != 0)
    return GImage(bands, alpha, metadata)

def read_metadata(gdal_ds):
    '''Collect the georeferencing information of ``gdal_ds`` into a dict.

    The result may contain the keys 'geotransform', 'projection' and
    'rpc'. A key is left out when the dataset reports the default
    geotransform, an empty projection or empty RPC metadata respectively.
    '''
    metadata = {}

    # Value this module treats as "no geotransform has been set".
    default_geotransform = (-1.0, 1.0, 0.0, 1.0, 0.0, -1.0)
    geotransform = gdal_ds.GetGeoTransform()
    if geotransform == default_geotransform:
        logging.debug('GImage: Raster has default geotransform, not storing')
    else:
        metadata['geotransform'] = geotransform

    projection = gdal_ds.GetProjection()
    if not projection:
        logging.debug(
            'GImage: Raster has no projection information, not storing')
    else:
        metadata['projection'] = projection

    rpc = gdal_ds.GetMetadata('RPC')
    if not rpc:
        logging.debug('GImage: Raster has no rpc information, not storing')
    else:
        metadata['rpc'] = rpc
    return metadata

def _read_all_bands(gdal_ds, band_count):
    '''Return a list with bands 1 to ``band_count`` of ``gdal_ds``, in order.

    Each band is read with read_single_band; gdal band numbers start at 1.
    '''
    return [read_single_band(gdal_ds, band_no)
            for band_no in range(1, band_count + 1)]

def read_single_band(gdal_ds, band_no):
    '''Read one raster band of ``gdal_ds`` as a uint16 numpy array.

    band_no is gdal style band numbering, i.e. from 1 onwards not 0 indexed

    Raises an Exception carrying gdal's last error message if either the
    band or its pixel data comes back as None.
    '''
    band = gdal_ds.GetRasterBand(band_no)
    if band is None:
        # Report a missing band the same way as a failed read, rather
        # than failing on the attribute access below.
        raise Exception(
            'GDAL error occured : {}'.format(gdal.GetLastErrorMsg()))
    array = band.ReadAsArray()
    if array is None:
        raise Exception(
            'GDAL error occured : {}'.format(gdal.GetLastErrorMsg()))
    return array.astype(numpy.uint16)

def read_alpha_and_band_count(gdal_ds, last_band_alpha=False):
    '''Return ``(alpha, band_count)`` for ``gdal_ds``.

    The last raster band is used as the alpha when gdal reports it as an
    alpha band, or when ``last_band_alpha`` is true. In both cases it is
    read as a boolean array and ``band_count`` is RasterCount - 1.
    Otherwise alpha is an all-True array of the raster's size and
    ``band_count`` is RasterCount.
    '''
    logging.info('GImage: Initial band count: {}'.format(
        gdal_ds.RasterCount))
    last_band = gdal_ds.GetRasterBand(gdal_ds.RasterCount)
    # The builtin bool is used as dtype; the numpy.bool alias is not
    # available in every numpy release.
    if last_band.GetColorInterpretation() == gdal.GCI_AlphaBand:
        logging.info('GImage: Alpha band found, reducing band count')
        alpha = last_band.ReadAsArray().astype(bool)
        band_count = gdal_ds.RasterCount - 1
    elif last_band_alpha:
        logging.info(
            'GImage: Forcing last band to be an alpha band, reducing band '
            'count')
        alpha = last_band.ReadAsArray().astype(bool)
        band_count = gdal_ds.RasterCount - 1
    else:
        logging.info('GImage: No alpha band found')
        # No alpha information: treat every pixel as valid.
        alpha = numpy.ones(
            (gdal_ds.RasterYSize, gdal_ds.RasterXSize),
            dtype=bool)
        band_count = gdal_ds.RasterCount
    return alpha, band_count

def _nodata_to_mask(bands, nodata):
    alpha = numpy.ones(bands[0].shape, dtype=numpy.uint16)
    for band in bands:
        alpha[band == nodata] = 0
    return alpha

def check_comparable(gimages, check_metadata=False):
    '''Checks that the gimages have the same number of bands, band dimensions,
    and, optionally, geospatial metadata

    Every image is compared against the first one in ``gimages``; only
    the shape of each image's first band is compared.

    - gimages: A sequence of GImages, the first being the reference
    - check_metadata: When true, the metadata dicts must be equal as well

    Raises an Exception describing the first difference found.
    '''

    no_bands = len(gimages[0].bands)
    band_shape = gimages[0].bands[0].shape
    metadata = gimages[0].metadata

    logging.debug('GImage: Initial image - band number, band shape: '
                  '{}, {}'.format(no_bands, band_shape))
    logging.debug('GImage: Initial image metadata: {}'.format(metadata))

    # start=1 so that i is the position of the image within gimages.
    for i, image in enumerate(gimages[1:], start=1):
        if len(image.bands) != no_bands:
            raise Exception(
                'Image {} has a different number of bands: '
                '{} (initial: {})'.format(i, len(image.bands), no_bands))

        if image.bands[0].shape != band_shape:
            raise Exception(
                'Image {} has a different band shape: {} (initial: {})'.format(
                    i, image.bands[0].shape, band_shape))

        if check_metadata and image.metadata != metadata:
            raise Exception(
                'Image {} has different geographic metadata: {} '
                '(initial: {})'.format(i, image.metadata, metadata))

def check_equal(gimages, check_metadata=False):
    '''Checks that a list of gimages are equivalent

    The images are first passed to check_comparable, then the band data
    and alpha of every image are compared against those of the first.

    - gimages: A sequence of GImages, the first being the reference
    - check_metadata: Passed through to check_comparable

    Raises an AssertionError from numpy.testing.assert_equal when band or
    alpha data differ, in addition to anything check_comparable raises.
    '''

    check_comparable(gimages, check_metadata)

    first_gimg = gimages[0]
    # start=1 so that the number in the message is the position of the
    # image within gimages, as in the messages of check_comparable.
    for i, image in enumerate(gimages[1:], start=1):
        numpy.testing.assert_equal(first_gimg.bands, image.bands,
                                   err_msg='Image {} has different band data'
                                   ' to the first image'.format(i))

        numpy.testing.assert_equal(first_gimg.alpha, image.alpha,
                                   err_msg='Image {} has different alpha data'
                                   ' to the first image'.format(i))