## These functions are tools to deal with rasters
## SMM
## 26/07/2014
#from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import absolute_import, division, print_function

import osgeo.gdal as gdal
import osgeo.gdal_array as gdal_array
import numpy as np
from osgeo import osr
from osgeo import ogr
import os
from os.path import exists
from osgeo.gdalconst import GA_ReadOnly

def getNoDataValue(rasterfn):
    """Get the nodata value of the first band of a raster.

    Args:
        rasterfn (str): The filename (with path and extension) of the raster.

    Returns:
        float: The nodata value of band 1 (None if GDAL reports no nodata
        value for that band).

    Raises:
        Exception: If GDAL cannot open the file.

    Author: SMM
    """
    raster = gdal.Open(rasterfn)
    if raster is None:
        raise Exception("Unable to read the data file")
    return raster.GetRasterBand(1).GetNoDataValue()

def setNoDataValue(rasterfn, noDataValue=-9999):
    """Set the nodata value on the first band of a raster.

    Args:
        rasterfn (str): The filename (with path and extension) of the raster.
        noDataValue (float): The value to register as nodata. Defaults to
            -9999, the nodata value used elsewhere in this module.

    Returns:
        The value returned by GDAL's ``Band.SetNoDataValue`` call.

    Raises:
        Exception: If GDAL cannot open the file for update.

    Author: SMM
    """
    # The dataset has to be opened in update mode, otherwise the new nodata
    # value cannot be stored with the raster.
    raster = gdal.Open(rasterfn, gdal.GA_Update)
    if raster is None:
        raise Exception("Unable to read the data file")

    band = raster.GetRasterBand(1)
    result = band.SetNoDataValue(noDataValue)
    band.FlushCache()

    # Drop the references so GDAL closes the dataset and writes the change.
    band = None
    raster = None
    return result

def GetUTMMaxMin(FileName):
    """Get the cell size and the minimum and maximum coordinates of a raster.

    *WARNING* the coordinates are derived from the GDAL geotransform using
    the x pixel size (GeoT[1]) in both directions, so the raster is assumed
    to be north-up with square pixels (e.g. already projected into UTM).

    Args:
        FileName (str): The filename (with path and extension) of the raster

    Returns:
        float: The cell size (in the units of the raster's coordinate system)
        float: The X minimum (easting)
        float: The X maximum (easting)
        float: The Y minimum (northing)
        float: The Y maximum (northing)

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    _, xsize, ysize, GeoT, _, _ = GetGeoInfo(FileName)

    CellSize = GeoT[1]

    # GeoT[0] and GeoT[3] are the coordinates of the upper-left corner
    XMin = GeoT[0]
    XMax = XMin + CellSize*xsize

    YMax = GeoT[3]
    YMin = YMax - CellSize*ysize

    return CellSize, XMin, XMax, YMin, YMax

# Gets the pixel area, assumes units are projected
def GetPixelArea(FileName):
    """Get the area of a single pixel, assuming square pixels.

    The area is the square of the x pixel size from the GDAL geotransform,
    so it is in m^2 only if the raster is projected in metres.

    Args:
        FileName (str): The filename (with path and extension) of the raster

    Returns:
        float: The area of each pixel

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    GeoT = GetGeoInfo(FileName)[3]
    CellSize = GeoT[1]

    return CellSize*CellSize

# This takes the rows and columns of the minimum and maximum values and
# converts them to UTM coordinates
def GetUTMMaxMinFromRowsCol(FileName,x_max_col,x_min_col,y_max_row,y_min_row):
    """Convert bounding row and column numbers into map coordinates.

    The coordinates are derived from the GDAL geotransform using the x pixel
    size in both directions, so the raster is assumed to be north-up with
    square pixels (e.g. already projected into UTM).

    Args:
        FileName (str): The filename (with path and extension) of the raster
        x_max_col (int): The column to use as the maximum
        x_min_col (int): The column to use as the minimum
        y_max_row (int): The row to use as the maximum
        y_min_row (int): The row to use as the minimum

    Returns:
        float: The X maximum (easting)
        float: The X minimum (easting)
        float: The Y maximum (northing)
        float: The Y minimum (northing)

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    _, _, ysize, GeoT, _, _ = GetGeoInfo(FileName)
    CellSize = GeoT[1]
    XMin = GeoT[0]

    YMax = GeoT[3]
    YMin = YMax - CellSize*ysize

    xmax_UTM = XMin + x_max_col*CellSize
    xmin_UTM = XMin + x_min_col*CellSize

    # Rows are counted from the top of the raster, whereas the northings are
    # measured up from the bottom edge (YMin), so the row numbers have to be
    # flipped: the smallest row number gives the largest northing.
    print("yll: "+str(YMin)+" and nrows: " +str(ysize) + " dx: "+str(CellSize))

    ymax_from_bottom = ysize - y_min_row
    ymin_from_bottom = ysize - y_max_row
    ymax_UTM = YMin + ymax_from_bottom*CellSize
    ymin_UTM = YMin + ymin_from_bottom*CellSize

    return xmax_UTM, xmin_UTM, ymax_UTM, ymin_UTM

# This gets the x and y vectors of the data
def GetLocationVectors(FileName):
    """Get vectors of the x and y coordinates of the raster cells.

    Each vector starts at the minimum coordinate and advances by one cell
    size per pixel, so it has exactly one entry per column (x) or row (y).
    The raster is assumed to be north-up with square pixels.

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        np.array: A vector of the x locations (eastings)
        np.array: A vector of the y locations (northings)

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    _, xsize, ysize, _, _, _ = GetGeoInfo(FileName)
    CellSize, XMin, _, YMin, _ = GetUTMMaxMin(FileName)

    # Build the vectors from the pixel counts rather than with
    # np.arange(min, max, step): with a floating point step the length of an
    # arange can come out one element too long or too short.
    x_vec = XMin + CellSize*np.arange(xsize)
    y_vec = YMin + CellSize*np.arange(ysize)

    return x_vec, y_vec

# This gets the extent of the raster
def GetRasterExtent(FileName):
    """Get the minimum and maximum coordinates of the raster as a list.

    The raster is assumed to be north-up with square pixels (see
    GetUTMMaxMin, which does the actual calculation).

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        list: A list of floats that contains

            * extent[0]: XMin
            * extent[1]: XMax
            * extent[2]: YMin
            * extent[3]: YMax

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    _, XMin, XMax, YMin, YMax = GetUTMMaxMin(FileName)
    return [XMin, XMax, YMin, YMax]

# Function to read the original file's projection:
def GetGeoInfo(FileName):
    """Get georeferencing information from the raster file using GDAL.

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        tuple: A tuple that contains:

            * NDV: the nodata value of band 1 (None if not set)
            * xsize: the number of columns in the raster
            * ysize: the number of rows in the raster
            * GeoT: the GDAL geotransform (a tuple of six numbers)
            * Projection: the projection (an osr.SpatialReference)
            * DataType: the GDAL name of the data type of band 1 (a string)

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    SourceDS = gdal.Open(FileName, gdal.GA_ReadOnly)
    if SourceDS is None:
        raise Exception("Unable to read the data file")

    band = SourceDS.GetRasterBand(1)
    NDV = band.GetNoDataValue()
    xsize = SourceDS.RasterXSize
    ysize = SourceDS.RasterYSize
    GeoT = SourceDS.GetGeoTransform()

    # Load the raster's projection into the spatial reference; without this
    # step the returned object would be an empty SpatialReference.
    Projection = osr.SpatialReference()
    Projection.ImportFromWkt(SourceDS.GetProjectionRef())

    DataType = gdal.GetDataTypeName(band.DataType)

    return NDV, xsize, ysize, GeoT, Projection, DataType

# This gets the UTM zone, if it exists
def GetUTMEPSG(FileName):
    """Uses GDAL to get the EPSG string of a UTM raster.

    The UTM zone and hemisphere are parsed from the name of the projected
    coordinate system (the PROJCS attribute), and the EPSG code is built as
    326 + zone (northern hemisphere) or 327 + zone (southern hemisphere),
    i.e. the WGS84 UTM codes.

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        str: The EPSG string, e.g. 'epsg:32630'. This is 'NULL' if the
        projected coordinate system has no name to parse.

    Raises:
        Exception: If the file does not exist, cannot be read, or is not in
            a projected coordinate system.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    # see if the file exists and get the dataset
    SourceDS = gdal.Open(FileName, gdal.GA_ReadOnly)
    if SourceDS is None:
        raise Exception("Unable to read the data file")

    EPSG_string = 'NULL'

    # get the projection
    print("Let me get that projection for you")
    print("In this function I will extract the UTM zone")
    srs = osr.SpatialReference(wkt=SourceDS.GetProjection())

    # IsProjected is a method: it has to be called, otherwise the test is
    # always true.
    if not srs.IsProjected():
        raise Exception("This is not a projected coordinate system!")

    print("The dataset is projected.")
    proj_str = srs.GetAttrValue(str('projcs'))
    if proj_str is None:
        return EPSG_string
    print("The projection string is: "+proj_str)

    # extract the UTM information
    if "UTM Zone" in proj_str:
        # name of the form "<word> <word> <zone>, <Northern|Southern> ..."
        # e.g. "UTM Zone 30, Northern Hemisphere"
        print("Found the string UTM zone")
        first_split = proj_str.split(',')
        first_half = first_split[0]
        second_half = first_split[1]
        if "Northern" in second_half:
            N_or_S = "N"
        else:
            N_or_S = "S"
        zone = first_half.split(' ')[2]
    elif "UTM zone" in proj_str:
        # name ending in "<zone><N|S>", e.g. "WGS 84 / UTM zone 30N"
        print("This seems to be from the new gdal version")
        zone_str = proj_str.split(' ')[-1]
        print("Zone string is: "+zone_str)
        zone = zone_str[:-1]
        print("The zone is: "+zone)

        if zone_str[-1] == "S":
            N_or_S = "S"
        else:
            N_or_S = "N"

        print("And the hemisphere is: "+N_or_S)
    elif "_Hemisphere" in proj_str:
        # underscore separated name with the zone as the third token
        if "Southern" in proj_str:
            N_or_S = "S"
        else:
            N_or_S = "N"
        zone = proj_str.split('_')[2]
    else:
        # underscore separated name ending in "<zone><N|S>"
        zone = proj_str.split('_')[-1]
        N_or_S = zone[-1]
        zone = zone[:-1]

    # adding some logic for zones < 10
    if len(zone) < 2:
        zone = '0'+zone

    if N_or_S == 'S':
        EPSG_string = 'epsg:327'+zone
    else:
        EPSG_string = 'epsg:326'+zone
    print("The EPSG string is: "+EPSG_string)

    return EPSG_string

# Function to count the pixels in the raster:
def GetNPixelsInRaster(FileName):
    """Get the total number of pixels in the raster.

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        int: The total number of pixels (columns multiplied by rows)

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    _, xsize, ysize, _, _, _ = GetGeoInfo(FileName)

    return xsize*ysize


# Function to make sure the ENVI header of a raster declares a nodata value:
def CheckNoData(FileName):
    """Make sure the ENVI header of a raster declares a nodata value.

    If GDAL reports no nodata value for band 1, the line
    "data ignore value = -9999" is appended to the header file, which is
    expected to have the same name as the raster but with a ".hdr" extension.

    Args:
        FileName (str): The filename (with path and extension) of the raster.

    Returns:
        int: The total number of pixels (what the function is really doing
        is updating the header file; the return value is just a check that
        the raster can be read).

    Raises:
        Exception: If the raster or its header file does not exist, or the
            raster cannot be read.

    Author: SMM
    """
    if not exists(FileName):
        raise Exception('[Errno 2] No such file or directory: \'' + FileName + '\'')

    # read the file, and check if there is a no data value
    SourceDS = gdal.Open(FileName, gdal.GA_ReadOnly)
    if SourceDS is None:
        raise Exception("Unable to read the data file")
    NoDataValue = SourceDS.GetRasterBand(1).GetNoDataValue()
    # release the dataset before the header is rewritten
    SourceDS = None

    print("In the check nodata routine. Nodata is: ")
    print(NoDataValue)

    if NoDataValue is None:
        print("This raster does not have no data. Updating the header file")
        header_name = os.path.splitext(FileName)[0]+".hdr"

        if not exists(header_name):
            raise Exception('[Errno 2] No such file or directory: \'' + header_name + '\'')

        # read the header
        with open(header_name, 'r') as this_file:
            lines = this_file.readlines()

        # make sure the new entry starts on its own line and ends with a
        # newline
        if lines and not lines[-1].endswith("\n"):
            lines[-1] = lines[-1]+"\n"
        lines.append("data ignore value = -9999\n")

        with open(header_name, 'w') as this_file:
            this_file.writelines(lines)

    _, xsize, ysize, _, _, _ = GetGeoInfo(FileName)

    return xsize*ysize


def ReadRasterArrayBlocks(raster_file,raster_band=1):
    """Read a raster band (using GDAL) into an array, one block at a time.

    Cells equal to the band's nodata value are replaced with NaN.

    Args:
        raster_file (str): The filename (with path and extension) of the raster.
        raster_band (int): the band of the raster (almost all uses with LSDTopoTools will have a 1 band raster)

    Returns:
        np.array: A 2D float array (rows, columns) with the data from the
        raster.

    Raises:
        Exception: If the file does not exist or cannot be read.

    Author: SMM
    """
    if not exists(raster_file):
        raise Exception('[Errno 2] No such file or directory: \'' + raster_file + '\'')

    dataset = gdal.Open(raster_file, GA_ReadOnly)
    if dataset is None:
        raise Exception("Unable to read the data file")

    band = dataset.GetRasterBand(raster_band)
    NoDataValue = band.GetNoDataValue()

    x_block_size, y_block_size = band.GetBlockSize()

    # If the block y size is 1, as in a GeoTIFF image, reading one line at a
    # time is slow, so at least 8 lines are read per block.
    if y_block_size < 8:
        y_block_size = 8

    xsize = band.XSize
    ysize = band.YSize

    print("xsize: " +str(xsize)+" and y size: " + str(ysize))

    # now initiate the array
    data_array = np.zeros((ysize,xsize))

    for i in range(0, ysize, y_block_size):
        # the last block in each direction may be smaller than the block size
        rows = min(y_block_size, ysize - i)

        for j in range(0, xsize, x_block_size):
            cols = min(x_block_size, xsize - j)

            # get the values for this block and move them to the data array
            data_array[i:i+rows, j:j+cols] = band.ReadAsArray(j, i, cols, rows)

    print("NoData is:", NoDataValue)
    if NoDataValue is not None:
        data_array[data_array == NoDataValue] = np.nan

    return data_array

def ReadRasterArrayBlocks_numpy(raster_file,raster_band=1):
    """Read an ENVI raster into an array using numpy.fromfile.

    The dimensions, data type and nodata value are read from the ENVI header
    (the raster filename with its extension replaced by ".hdr"); the binary
    file is then loaded in one go. Cells equal to the nodata value are
    replaced with NaN. The header offset and byte order entries of the header
    are not consulted.

    Args:
        raster_file (str): The filename (with path and extension) of the raster.
        raster_band (int): the band of the raster (almost all uses with LSDTopoTools will have a 1 band raster) Doesn't work yet for more than one band.

    Returns:
        np.array: A 2D float array (lines, samples) with the data from the
        raster.

    Raises:
        Exception: If the raster does not exist, or the header lacks the
            samples, lines or data type entries, or the data type is not
            supported.

    Author: SMM
    """
    print("I will now use numpy.fromfile to load your raster, I still need to be tested, If something goes wrong switch back to the classic method by switching NFF_opti = False. Classic method = unefficient.")
    if not exists(raster_file):
        raise Exception('[Errno 2] No such file or directory: \'' + raster_file + '\'')

    # The ENVI codes for the type of data representation:
    # 1 = Byte: 8-bit unsigned integer
    # 2 = Integer: 16-bit signed integer
    # 3 = Long: 32-bit signed integer
    # 4 = Floating-point: 32-bit single-precision
    # 5 = Double-precision: 64-bit double-precision floating-point
    # 12 = Unsigned integer: 16-bit
    # 13 = Unsigned long integer: 32-bit
    # 14 = 64-bit long integer (signed)
    # 15 = 64-bit unsigned long integer (unsigned)
    # The complex types (6 and 9) are not supported.
    envi_dtypes = {1: 'uint8', 2: 'int16', 3: 'int32', 4: 'float32',
                   5: 'float64', 12: 'uint16', 13: 'uint32', 14: 'int64',
                   15: 'uint64'}

    info_dtype = None
    num_col = None
    num_lines = None
    no_data_hdr = -9999 # the default nodata in case there is no value set in the hdr file

    header_name = os.path.splitext(raster_file)[0]+".hdr"
    with open(header_name, "r") as hdr_file:
        print("I am opening your raster info")
        for line in hdr_file:
            if "=" not in line:
                continue
            # Split on the first "=" only, and take the whole value rather
            # than the last characters of the line, so that two-digit data
            # type codes are read correctly.
            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()

            if key == "data type":
                info_dtype = int(value)
            elif key == "samples":
                print("there are " + value + " columns")
                num_col = int(value)
            elif key == "lines":
                print("there are " + value + " lines")
                num_lines = int(value)
            elif key == "data ignore value":
                print("No data value is: " + value)
                no_data_hdr = float(value)

    if info_dtype is None or num_col is None or num_lines is None:
        raise Exception("The header file " + header_name + " must contain the samples, lines and data type entries")
    if info_dtype not in envi_dtypes:
        raise Exception("Unsupported ENVI data type: " + str(info_dtype))

    data_type = np.dtype(envi_dtypes[info_dtype])
    print("your data type is "  + str(data_type))

    # Alright now loading and converting the data
    print("I am now ingesting your raster")
    data_array = np.fromfile(raster_file, data_type).reshape(num_lines, num_col)
    # integer arrays cannot hold NaN, so they are converted to floats
    if data_array.dtype.kind != 'f':
        data_array = data_array.astype(float)
    print("I nailed it")

    print("I am returning the raster array and info")
    data_array[data_array == no_data_hdr] = np.nan

    return data_array

def array2raster(rasterfn,newRasterfn,array,driver_name = "ENVI", noDataValue = -9999):
    """Takes an array and writes to a GDAL compatible raster. It needs another raster to map the dimensions.

    The new raster has a single Float32 band, and takes its size, origin,
    pixel size and projection from the template raster.

    Args:
        rasterfn (str): The filename (with path and extension) of a raster that has the same dimensions as the raster to be written.
        newRasterfn (str): The filename (with path and extension) of the new raster.
        array (np.array): The array to be written
        driver_name (str): The type of raster to write. Default is ENVI since that is the LSDTopoTools format
        noDataValue (float): The no data value

    Returns:
        None: the result is the raster written to newRasterfn.

    Raises:
        Exception: If the template raster cannot be read or the new raster
            cannot be created.

    Author: SMM
    """
    raster = gdal.Open(rasterfn)
    if raster is None:
        raise Exception("Unable to read the data file")

    geotransform = raster.GetGeoTransform()
    originX = geotransform[0]
    originY = geotransform[3]
    pixelWidth = geotransform[1]
    pixelHeight = geotransform[5]
    cols = raster.RasterXSize
    rows = raster.RasterYSize

    driver = gdal.GetDriverByName(driver_name)
    outRaster = driver.Create(newRasterfn, cols, rows, 1, gdal.GDT_Float32)
    if outRaster is None:
        raise Exception("Unable to create the raster " + newRasterfn)
    outRaster.SetGeoTransform((originX, pixelWidth, 0, originY, 0, pixelHeight))

    outband = outRaster.GetRasterBand(1)
    outband.SetNoDataValue(noDataValue)
    outband.WriteArray(array)

    # copy the projection of the template raster to the new raster
    outRasterSRS = osr.SpatialReference()
    outRasterSRS.ImportFromWkt(raster.GetProjectionRef())
    outRaster.SetProjection(outRasterSRS.ExportToWkt())

    # make sure everything is written to disk, then close the new raster
    outband.FlushCache()
    outband = None
    outRaster = None

def RasterDifference(RasterFile1, RasterFile2, raster_band=1, OutFileName="Test.outfile", OutFileType="ENVI"):
    """
    Takes two rasters of same size and subtracts second from first,
    e.g. Raster1 - Raster2 = raster_of_difference
    then writes it out to file

    The output is a single band Float32 raster whose georeferencing is
    copied from the first raster.

    Args:
        RasterFile1 (str): The filename (with path and extension) of the first raster.
        RasterFile2 (str): The filename (with path and extension) of the raster to subtract.
        raster_band (int): The band to read from both rasters.
        OutFileName (str): The filename (with path and extension) of the raster to write.
        OutFileType (str): The name of the GDAL driver used for the output.

    Returns:
        None: the result is the raster written to OutFileName.

    Raises:
        Exception: If either raster cannot be read.
        AssertionError: If the two rasters do not have the same shape.
    """
    Raster1 = gdal.Open(RasterFile1)
    Raster2 = gdal.Open(RasterFile2)
    if Raster1 is None or Raster2 is None:
        raise Exception("Unable to read the data file")

    print("RASTER 1: ")
    print(Raster1.GetGeoTransform())

    print("RASTER 2: ")
    print(Raster2.GetGeoTransform())

    raster_array1 = np.array(Raster1.GetRasterBand(raster_band).ReadAsArray())
    raster_array2 = np.array(Raster2.GetRasterBand(raster_band).ReadAsArray())

    assert raster_array1.shape == raster_array2.shape, "The two rasters must have the same shape"
    print("Shapes: ", raster_array1.shape, raster_array2.shape)

    # Subtract as floats: with unsigned integer rasters a negative difference
    # would otherwise wrap around to a large positive number.
    difference_raster_array = raster_array1.astype(float) - raster_array2.astype(float)

    driver = gdal.GetDriverByName(OutFileType)

    dsOut = driver.Create(OutFileName,
                          Raster1.RasterXSize,
                          Raster1.RasterYSize,
                          1,
                          gdal.GDT_Float32)
    # give the output the same georeferencing as the first raster
    gdal_array.CopyDatasetInfo(Raster1, dsOut)

    bandOut = dsOut.GetRasterBand(1)
    gdal_array.BandWriteArray(bandOut, difference_raster_array)

    # make sure everything is written to disk, then close the output
    bandOut.FlushCache()
    bandOut = None
    dsOut = None

def PolygoniseRaster(DataDirectory, RasterFile, OutputShapefile='polygons'):
    """
    This function takes in a raster and converts to a polygon shapefile using rasterio
    from https://gis.stackexchange.com/questions/187877/how-to-polygonize-raster-to-shapely-polygons/187883#187883?newreg=8b1f507529724a8488ce4789ba787363

    If several polygons share the same raster value, only the one with the
    largest area is kept.

    Args:
        DataDirectory (str): the data directory with the basin raster
        RasterFile (str): the name of the raster
        OutputShapefile (str): the name of the output shapefile WITHOUT EXTENSION. Default = 'polygons'

    Returns:
        Dictionary where key is the raster value and the value is a shapely polygon

    Author: FJC
    """
    # import modules
    import rasterio
    from rasterio.features import shapes
    from shapely.geometry import shape, Polygon, mapping
    import fiona

    raster_band = 1

    # get raster no data value
    NDV = getNoDataValue(DataDirectory+RasterFile)

    # load in the raster using rasterio and extract the shapes while the
    # file is still open
    with rasterio.open(DataDirectory+RasterFile) as src:
        image = src.read(raster_band, masked=False)
        msk = src.read_masks(1)
        geoms = [{'properties': {'raster_val': v}, 'geometry': s}
                 for s, v in shapes(image, mask=msk, transform=src.transform)]

    # define shapefile attributes
    print("Let me grab the coordinate reference system.")
    crs = GetUTMEPSG(DataDirectory+RasterFile)
    schema = {'geometry': 'Polygon',
              'properties': { 'ID': 'float'}}

    # This is necessary to filter the basin results: for each ID only the
    # polygon with the largest area is kept
    filtered_geoms = {}
    area_dict = {}
    for f in geoms:
        this_val = float(f['properties']['raster_val'])
        this_area = Polygon(shape(f['geometry'])).area
        if this_val in filtered_geoms:
            print("Whoops. Found a repeated ID. Getting rid of the smaller one.")
            if area_dict[this_val] < this_area:
                filtered_geoms[this_val] = f
                area_dict[this_val] = this_area
                print("Found a repeated ID. Keeping the one with area of "+str(this_area))
            else:
                print("Keeping the initial ID.")
        else:
            filtered_geoms[this_val] = f
            area_dict[this_val] = this_area

    new_geoms = list(filtered_geoms.values())

    # transform results into shapely geometries and write to shapefile using fiona
    PolygonDict = {}
    with fiona.open(DataDirectory+OutputShapefile, 'w', crs=crs, driver='ESRI Shapefile', schema=schema) as output:
        for f in new_geoms:
            this_shape = Polygon(shape(f['geometry']))
            this_val = float(f['properties']['raster_val'])
            print("ID is: "+str(this_val))
            if this_val != NDV: # remove no data values
                output.write({'geometry': mapping(this_shape), 'properties':{'ID': this_val}})
            PolygonDict[this_val] = this_shape

    return PolygonDict

def PolygoniseRasterMerge(DataDirectory, RasterFile, OutputShapefile='polygons'):
    """
    This function takes in a raster and converts to a polygon shapefile using rasterio
    from https://gis.stackexchange.com/questions/187877/how-to-polygonize-raster-to-shapely-polygons/187883#187883?newreg=8b1f507529724a8488ce4789ba787363

    This version recognises where there are multiple polygons with the same key and merges
    them into a single geometry using unary_union. Each key is written to the
    shapefile once, after all of its polygons have been merged.

    Args:
        DataDirectory (str): the data directory with the basin raster
        RasterFile (str): the name of the raster
        OutputShapefile (str): the name of the output shapefile WITHOUT EXTENSION. Default = 'polygons'

    Returns:
        Dictionary where key is the raster value and the value is a shapely
        geometry (the union of every polygon with that value)

    Author: FJC
    """
    # import modules
    import rasterio
    from rasterio.features import shapes
    from shapely.geometry import shape, Polygon, mapping
    from shapely.ops import unary_union
    import fiona

    raster_band = 1

    # get raster no data value
    NDV = getNoDataValue(DataDirectory+RasterFile)

    # load in the raster using rasterio and extract the shapes while the
    # file is still open
    with rasterio.open(DataDirectory+RasterFile) as src:
        image = src.read(raster_band, masked=False)
        msk = src.read_masks(1)
        geoms = list(shapes(image, mask=msk, transform=src.transform))

    # define shapefile attributes
    crs = GetUTMEPSG(DataDirectory+RasterFile)
    schema = {'geometry': 'Polygon',
              'properties': { 'ID': 'float'}}

    # collect every polygon that belongs to each raster value
    parts_by_val = {}
    for geometry, raster_val in geoms:
        this_val = float(raster_val)
        parts_by_val.setdefault(this_val, []).append(Polygon(shape(geometry)))

    # merge the polygons that share a value
    PolygonDict = {}
    for this_val, parts in parts_by_val.items():
        if len(parts) == 1:
            PolygonDict[this_val] = parts[0]
        else:
            PolygonDict[this_val] = unary_union(parts)

    # write the merged geometries to the shapefile using fiona
    with fiona.open(DataDirectory+OutputShapefile, 'w', crs=crs, driver='ESRI Shapefile', schema=schema) as output:
        for this_val, this_shape in PolygonDict.items():
            if this_val != NDV: # remove no data values
                output.write({'geometry': mapping(this_shape), 'properties':{'ID': this_val}})

    return PolygonDict

def CreateShapefileOfRasterFootprint(DataDirectory, RasterFile):
    """
    This function takes a raster and creates a shapefile that is the footprint
    of the raster. Used for plotting the raster footprint on regional maps.

    The footprint is the bounding rectangle of the raster, transformed from
    the raster's UTM coordinate system to WGS84 longitude and latitude
    (EPSG:4326). It is written to <raster prefix>_footprint.shp in the data
    directory.

    Args:
        DataDirectory (str): the data directory with the basin raster
        RasterFile (str): the name of the raster

    Returns:
        None: the result is the shapefile of the raster footprint

    Author: SMM
    Date: 23/01/2018
    """
    print("Trying to create a shapefile.")
    print("The Data directory is: "+DataDirectory+ " and the raster is: "+ RasterFile)
    driver_name = "ESRI shapefile"
    driver = ogr.GetDriverByName(driver_name)

    # get the filename of the outfile.
    if not DataDirectory.endswith(os.sep):
        print("You forgot the separator at the end of the directory, appending...")
        DataDirectory = DataDirectory+os.sep

    # Get the raster prefix
    RasterPrefix = RasterFile.split(".")[0]

    # get the espg of the raster
    FullFilename = DataDirectory+RasterFile
    ESPG_this_raster = str(GetUTMEPSG(FullFilename))
    print("The raster has coordinate of: "+ESPG_this_raster)
    ESPG_this_raster = ESPG_this_raster.split(":")[-1]
    print ("This ESPG is: "+str(ESPG_this_raster))

    # Get extent of raster
    [xmin,xmax,ymin,ymax] = GetRasterExtent(FullFilename)

    # Create ring
    ring = ogr.Geometry(ogr.wkbLinearRing)
    ring.AddPoint(xmin, ymin)
    ring.AddPoint(xmin, ymax)
    ring.AddPoint(xmax, ymax)
    ring.AddPoint(xmax, ymin)
    ring.AddPoint(xmin, ymin)

    # Create polygon
    poly = ogr.Geometry(ogr.wkbPolygon)
    poly.AddGeometry(ring)

    # Create a coordinate transformation
    source = osr.SpatialReference()
    source.ImportFromEPSG(int(ESPG_this_raster))
    target = osr.SpatialReference()
    target.ImportFromEPSG(4326)
    # From GDAL 3 onwards EPSG:4326 uses latitude, longitude axis order by
    # default; ask for the traditional x, y (longitude, latitude) order so
    # the result is the same for all GDAL versions.
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        source.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        target.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    transform = osr.CoordinateTransformation(source, target)

    # now transform the polygon
    poly.Transform(transform)

    # create the data source
    OutFileName = DataDirectory+RasterPrefix+"_footprint.shp"
    print("The output shapefile is: "+OutFileName)
    # an existing shapefile has to be removed before it can be created again
    if exists(OutFileName):
        driver.DeleteDataSource(OutFileName)
    datasource = driver.CreateDataSource(OutFileName)

    # create the layer and add the footprint to it
    layer = datasource.CreateLayer(OutFileName, target, ogr.wkbPolygon)
    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetGeometry(poly)
    layer.CreateFeature(feature)

    # Clean up: dropping the references closes and writes the shapefile
    feature = None
    layer = None
    datasource = None

def GetCentreAndExtentOfRaster(DataDirectory, RasterFile):
    """
    This function takes a raster and returns the centrepoint and the extent in both degrees and metres.

    The corners of the raster are transformed from the raster's UTM
    coordinate system to WGS84 longitude and latitude (EPSG:4326); the extent
    in degrees is measured between the lower-left and upper-right corners.

    Args:
        DataDirectory (str): the data directory with the basin raster
        RasterFile (str): the name of the raster

    Returns:
        The lat-long of the centrepoint, the x-y- extent in both degrees and metres,
        as the tuple (centre_lat, centre_long, extent_lat, extent_long,
        xproj_extent, yproj_extent)

    Author: SMM
    Date: 01/02/2018
    """
    print("Trying to create a shapefile.")
    print("The Data directory is: "+DataDirectory+ " and the raster is: "+ RasterFile)

    # make sure the directory ends with a separator
    if not DataDirectory.endswith(os.sep):
        print("You forgot the separator at the end of the directory, appending...")
        DataDirectory = DataDirectory+os.sep

    # get the espg of the raster
    FullFilename = DataDirectory+RasterFile
    ESPG_this_raster = str(GetUTMEPSG(FullFilename))
    print("The raster has coordinate of: "+ESPG_this_raster)
    ESPG_this_raster = ESPG_this_raster.split(":")[-1]
    print ("This ESPG is: "+str(ESPG_this_raster))

    # Get extent of raster
    [xmin,xmax,ymin,ymax] = GetRasterExtent(FullFilename)
    xproj_extent = xmax-xmin
    yproj_extent = ymax-ymin

    # Create ring
    ring = ogr.Geometry(ogr.wkbLinearRing)
    ring.AddPoint(xmin, ymin)
    ring.AddPoint(xmin, ymax)
    ring.AddPoint(xmax, ymax)
    ring.AddPoint(xmax, ymin)

    # Create a coordinate transformation
    source = osr.SpatialReference()
    source.ImportFromEPSG(int(ESPG_this_raster))
    target = osr.SpatialReference()
    target.ImportFromEPSG(4326)
    # From GDAL 3 onwards EPSG:4326 uses latitude, longitude axis order by
    # default; ask for the traditional x, y (longitude, latitude) order so
    # the points below are (long, lat) for all GDAL versions.
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        source.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        target.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    transform = osr.CoordinateTransformation(source, target)

    # now transform the ring so you can get coordinates in lat-long
    ring.Transform(transform)

    # now get the xmin,ymin, and xmax, ymax coords in lat-long
    pt1 = ring.GetPoint(0)
    min_long = pt1[0]
    min_lat = pt1[1]
    pt2 = ring.GetPoint(2)
    max_long = pt2[0]
    max_lat = pt2[1]

    extent_long = max_long-min_long
    extent_lat = max_lat-min_lat
    centre_long = min_long+extent_long*0.5
    centre_lat = min_lat+extent_lat*0.5

    return centre_lat, centre_long, extent_lat, extent_long, xproj_extent, yproj_extent
    # Leaving this here to show how to loop through points
    # Now try with the geometry tools
    #print("Trying ring")
    #for i in range(0, ring.GetPointCount()):
    #    pt = ring.GetPoint(i)
    #    print("The point is: ")
    #    print(str(pt[0])+" "+str(pt[1]))