Python osgeo.gdal.Dataset() Examples

The following code examples show how to use osgeo.gdal.Dataset(). They are taken from open source Python projects. You can vote up the examples you like or vote down the ones you don't like.

Example 1
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 6 votes vote down vote up
def spectra_at_xy(rast, xy, gt=None, wkt=None, dd=False):
    '''
    Returns the spectral profile of the pixels located at the given
    longitude-latitude (X-Y) pairs. Arguments:
        rast    A gdal.Dataset or NumPy array; a gdal.Dataset supplies its
                own GeoTransform and projection, overriding gt and wkt
        xy      An array of X-Y (e.g., longitude-latitude) pairs
        gt      A GDAL GeoTransform tuple; ignored for gdal.Dataset
        wkt     Well-Known Text projection information; ignored for
                gdal.Dataset
        dd      Interpret the longitude-latitude pairs as decimal degrees
    Returns a (q x p) array of q endmembers with p bands.
    '''
    if isinstance(rast, np.ndarray):
        arr = rast
    else:
        # A gdal.Dataset carries its own georeferencing
        gt, wkt = rast.GetGeoTransform(), rast.GetProjection()
        arr = rast.ReadAsArray()

    # Transposing the whole array turns out to be faster than transposing
    #   the coordinate pairs
    pixel_coords = xy_to_pixel(xy, gt=gt, wkt=wkt, dd=dd)
    return spectra_at_idx(arr.transpose(), pixel_coords)
Example 2
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 6 votes vote down vote up
def __init__(self, filepaths, band=0):
        '''
        Build an in-memory VRT that stacks one band from each file in
        filepaths, then initialise this object as a Dataset over it.

        filepaths  Sequence of raster paths; the first is the reference
                   dataset whose environment is applied to the others
        band       Band index handed to buildvrt (presumably zero-based,
                   since buildvrt adds 1 for GDAL) - confirm
        '''
        self._datasets=[]#So they don't go out of scope and get GC'd

        #Get a reference dataset so can apply env setting to all datasets
        reference_ds=Dataset(filepaths[0])
        for f in filepaths[1:]:
            d=Dataset(f)
            reference_ds,d=reference_ds.apply_environment(d)

        vrtxml=self.buildvrt(reference_ds, filepaths, band)

        #Temp in memory VRT file. The builtin next() works on Python 2.6+ and
        #Python 3, unlike the Python 2-only .next() method.
        self._filename='/vsimem/%s.vrt'%next(tempfile._RandomNameSequence())
        self.__write_vsimem__(self._filename,vrtxml)
        self._dataset=gdal.Open(self._filename)

        Dataset.__init__(self)
Example 3
Project: pygeotools   Author: dshean   File: iolib.py    MIT License 6 votes vote down vote up
def gdal2np_dtype(b):
    """
    Return the NumPy datatype matching a GDAL RasterBand datatype.

    The input may be a filename (opened with gdal.Open), a GDAL Dataset
    (its first band is used), a GDAL RasterBand, or a GDAL integer dtype
    code. The final int code is looked up in gdal_array.codes. If the input
    cannot be narrowed to an int, a message is printed and None is returned.
    """
    codes = gdal_array.codes
    # Narrow step by step: path -> Dataset -> Band -> integer type code
    if isinstance(b, str):
        b = gdal.Open(b)
    if isinstance(b, gdal.Dataset):
        b = b.GetRasterBand(1)
    if isinstance(b, gdal.Band):
        b = b.DataType
    if not isinstance(b, int):
        print("Input must be GDAL Dataset or RasterBand object")
        return None
    return codes[b]

#Replace nodata value in GDAL band 
Example 4
Project: GeoPy   Author: aerler   File: gdal.py    GNU General Public License v3.0 5 votes vote down vote up
def getGridDef(var):
  ''' Build a GridDefinition describing the grid of a GDAL-enabled Variable or Dataset.
      Raises GDALError when 'gdal' is not among var's instance attributes. '''
  if 'gdal' in var.__dict__:
    # the grid takes the variable's name with a '_grid' suffix
    return GridDefinition(name=var.name+'_grid', projection=var.projection,
                          geotransform=var.geotransform, size=var.mapSize,
                          xlon=var.xlon, ylat=var.ylat)
  raise GDALError


## gid pickle functions 
Example 5
Project: GeoPy   Author: aerler   File: gdal.py    GNU General Public License v3.0 5 votes vote down vote up
def getGeotransform(xlon=None, ylat=None, geotransform=None):
  ''' Function to check or infer GDAL geotransform from coordinate axes. 
      Note that due to machine-precision errors, recomputing the geotransform from coordinates can cause problems.

      xlon, ylat    coordinate Axis instances; both are required (with data) when geotransform is None;
                    when a geotransform is given they are optional, and each axis that is present and
                    holds data is checked for consistency with the geotransform
      geotransform  GDAL geotransform (ulx, dx, 0, uly, 0, dy), or None to infer it from the axes

      Returns the geotransform as a tuple of floats.
      Raises TypeError if the geotransform must be inferred and xlon/ylat are not Axis instances,
      DataError if the axes hold no data to infer it from, and GDALError if a given geotransform
      is malformed or inconsistent with the axes. '''
  if geotransform is None:  # infer geotransform from axes
    if not isinstance(ylat, Axis) or not isinstance(xlon, Axis): raise TypeError
    if xlon.data and ylat.data:
      # infer GDAL geotransform vector from coordinate vectors (axes)
      dx = xlon[1] - xlon[0]; dy = ylat[1] - ylat[0]
      # coordinates mark cell centres, so the upper left corner lies half a cell before the first coordinate
      ulx = xlon[0] - dx / 2.; uly = ylat[0] - dy / 2.
      # GT(2) & GT(4) are zero for North-up; GT(1) & GT(5) are pixel width and height; (GT(0),GT(3)) is the top left corner
      geotransform = (float(ulx), float(dx), 0., float(uly), 0., float(dy))
    else: raise DataError("Coordinate vectors are required to infer GDAL geotransform vector.")
  else:  # check given geotransform
    geotransform = tuple(float(f) for f in geotransform)
    # the axes default to None, so only check the ones that were passed and actually hold data
    check_x = xlon is not None and xlon.data
    check_y = ylat is not None and ylat.data
    if check_x or check_y:
      # check if GDAL geotransform vector is consistent with coordinate vectors
      if not len(geotransform) == 6:
        raise GDALError('\'geotransform\' has to be a vector or list with 6 elements.')
      dx = geotransform[1]; dy = geotransform[5]; ulx = geotransform[0]; uly = geotransform[3] 
      # coordinates of upper left corner (same for source and sink)
      if check_x and not isEqual(ulx, float(xlon[0]) - dx / 2.): raise GDALError('{} != {}'.format(ulx, float(xlon[0]) - dx / 2.))
      if check_y and not isEqual(uly, float(ylat[0]) - dy / 2.): raise GDALError('{} != {}'.format(uly, float(ylat[0]) - dy / 2.))
    else: 
      if not ( len(geotransform) == 6 and all(isFloat(geotransform)) ):
        raise GDALError('\'geotransform\' has to be a vector or list of 6 floating-point numbers.')
  # return results
  return geotransform


## functions to add GDAL functionality to existing Variable and Dataset instances 
Example 6
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def as_raster(path):
    '''
    Opens a raster file and returns it together with its spatial metadata.
    Arguments:
        path    The path of the raster file to open as a gdal.Dataset
    Returns a (gdal.Dataset, GeoTransform tuple, WKT projection) tuple.
    '''
    dataset = gdal.Open(path)
    return (dataset, dataset.GetGeoTransform(), dataset.GetProjection())
Example 7
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def cfmask(mask, mask_values=(1,2,3,4,255), nodata=-9999):
    '''
    Returns a binary mask according to the CFMask algorithm results for the
    image; mask has 1 for water, cloud, shadow, and snow (if any) and 0
    everywhere else. More information can be found:
        https://landsat.usgs.gov/landsat-surface-reflectance-quality-assessment

    Landsat 4-7 Pre-Collection pixel_qa values to be masked:
        mask_values = (1, 2, 3, 4)

    Landsat 4-7 Collection 1 pixel_qa values to be masked (for "Medium" confidence):
        mask_values = (1, 68, 72, 80, 112, 132, 136, 144, 160, 176, 224)

    Landsat 8 Collection 1 pixel_qa values to be masked (for "Medium" confidence):
        mask_values = (1, 324, 328, 386, 388, 392, 400, 416, 432, 480, 832, 836, 840, 848, 864, 880, 900, 904, 912, 928, 944, 992, 1024)

    Arguments:
        mask        A single-band gdal.Dataset or NumPy array, shaped either
                    (rows x cols) or (1 x rows x cols)
        mask_values The values in the mask that correspond to NoData pixels
        nodata      Currently unused; kept for backward compatibility
    Returns a (1 x rows x cols) integer (np.intp) array: 1 where the mask
    value is one of mask_values, 0 elsewhere. A NumPy input is not modified.
    '''
    if not isinstance(mask, np.ndarray):
        maskr = mask.ReadAsArray()

    else:
        maskr = mask

    # Mask according to bit-packing described here:
    # https://landsat.usgs.gov/landsat-surface-reflectance-quality-assessment
    # np.isin keeps the input's shape, and np.intp replaces the np.int0
    #   alias, which NumPy 2.0 removed
    rows, cols = maskr.shape[-2], maskr.shape[-1]
    return np.isin(maskr, mask_values).reshape((1, rows, cols)).astype(np.intp)
Example 8
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def clean_mask(rast):
    '''
    Clamps every value of a mask into the closed interval [0, 1]: anything
    above 1 becomes 1 and anything below 0 becomes 0.
    Arguments:
        rast    An input gdal.Dataset or numpy.array instance
    Returns a new array; a NumPy input is left unmodified.
    '''
    rastr = rast.copy() if isinstance(rast, np.ndarray) else rast.ReadAsArray()
    return np.clip(rastr, a_min=0, a_max=1)
Example 9
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def copy_nodata(source, target, nodata=-9999):
    '''
    Embeds the NoData pixels of a source raster (or array) into a target
    raster (or array). This is useful, for instance, for masking out dropped
    scanlines in a Landsat 7 image: those areas are NoData in the EOS HDF
    but are not included in the QA mask. Arguments:
        source  A gdal.Dataset or a NumPy array
        target  A gdal.Dataset or a NumPy array
        nodata  The NoData value to look for (and embed)
    Returns a new array shaped like target. For a 3-D source, only its
    first band defines where NoData is; that footprint is applied to every
    band of target.
    '''
    src = source if isinstance(source, np.ndarray) else source.ReadAsArray()
    tgt = target.copy() if isinstance(target, np.ndarray) else target.ReadAsArray()

    assert src.ndim == tgt.ndim, "Source and target rasters must have the same number of axes"

    if src.ndim != 3:
        assert src.shape == tgt.shape, "Source and target rasters must have the same shape"
        return np.where(src == nodata, nodata, tgt)

    assert src.shape[1:] == tgt.shape[1:], "Source and target rasters must have the same shape (not including band axis)"
    # The first source band's NoData footprint broadcasts across all target bands
    return np.where(src[0, ...] == nodata, nodata, tgt)
Example 10
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def dump_raster(rast, rast_path, driver='GTiff', gdt=None, nodata=None):
    '''
    Creates a raster file from a given gdal.Dataset instance. Arguments:
        rast        A gdal.Dataset; does NOT accept NumPy array
        rast_path   The path of the output raster file
        driver      The name of the GDAL driver to use (determines file type)
        gdt         The GDAL data type to use, e.g., see gdal.GDT_Float32;
                    defaults to the data type of the first band of rast
        nodata      The NoData value applied to every output band; when None,
                    each output band takes its own source band's NoData value,
                    or -9999 if that source band has none.
    Band statistics (min, max, mean, std) are computed from all pixel values
    and written with SetStatistics.
    '''
    if gdt is None:
        gdt = rast.GetRasterBand(1).DataType
    driver = gdal.GetDriverByName(driver)
    sink = driver.Create(
        rast_path, rast.RasterXSize, rast.RasterYSize, rast.RasterCount, int(gdt))
    assert sink is not None, 'Cannot create dataset; there may be a problem with the output path you specified'
    sink.SetGeoTransform(rast.GetGeoTransform())
    sink.SetProjection(rast.GetProjection())

    for b in range(1, rast.RasterCount + 1):
        src_band = rast.GetRasterBand(b)
        dst_band = sink.GetRasterBand(b)
        dat = src_band.ReadAsArray()
        dst_band.WriteArray(dat)
        dst_band.SetStatistics(*map(np.float64,
            [dat.min(), dat.max(), dat.mean(), dat.std()]))

        # Resolve NoData per band, so one band's value never leaks into the
        #   bands that follow it
        band_nodata = nodata
        if band_nodata is None:
            band_nodata = src_band.GetNoDataValue()
        if band_nodata is None:
            band_nodata = -9999

        dst_band.SetNoDataValue(np.float64(band_nodata))

    sink.FlushCache()
Example 11
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def mask_by_query(rast, query, invert=False, nodata=-9999):
    '''
    Sets pixels (across bands) that match a query to NoData. For example,
    `query = rast[1,...] < -25` selects the pixels whose band-2 value is
    below -25; those pixels are masked when `invert=False`. With
    `invert=True` the query instead selects the pixels to KEEP (the same as
    passing `np.invert(query)`). Arguments:
        rast    A gdal.Dataset or numpy.array instance
        query   A NumPy boolean array representing the result of a query;
                either the raster's shape, or a 2-D (single-band) array with
                the raster's extent, which is applied to every band
        invert  True to invert the query
        nodata  The NoData value to apply in the masking
    Returns a new array; a NumPy input is left unmodified.
    '''
    if isinstance(rast, np.ndarray):
        rastr = rast.copy()
    else:
        rastr = rast.ReadAsArray()

    shp = rastr.shape
    if query.shape != shp:
        q_dims = len(query.shape)
        assert q_dims == 2 or q_dims == len(shp), 'Query must either be 2-dimensional (single-band) or have a dimensionality equal to the raster array'
        assert shp[-2] == query.shape[-2] and shp[-1] == query.shape[-1], 'Raster and query must be conformable arrays in two dimensions (must have the same extent)'

        # Promote the query to one band, then copy it onto every band
        query = np.repeat(query.reshape((1, shp[-2], shp[-1])), shp[0], axis=0)

    selector = np.invert(query) if invert else query
    rastr[selector] = nodata
    return rastr
Example 12
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def saturation_mask(rast, saturation_value=10000, nodata=-9999):
    '''
    Returns a binary mask that is True where a pixel is saturated in at least
    one band, i.e., exceeds saturation_value, and False everywhere else.
    (For example, surface reflectance values can exceed 16,000, but SR values
    are only considered valid on the range [0, 10,000].)
    Arguments:
        rast                A gdal.Dataset or NumPy array, either
                            (bands x rows x cols) or a single band (rows x cols)
        saturation_value    The value beyond which pixels are considered
                            saturated
        nodata              Currently unused; kept for backward compatibility
    Returns a (1 x rows x cols) boolean array.
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast

    # A single-band raster may be read as 2-D; treat it as one band
    if rastr.ndim == 2:
        rastr = rastr[np.newaxis, ...]

    # Saturated in any band. A boolean result (instead of the float array
    #   that np.empty() would produce) can be used directly as an index mask.
    return (rastr > saturation_value).any(axis=0, keepdims=True)
Example 13
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 5 votes vote down vote up
def subarray(rast, filtered_value=-9999, indices=False):
    '''
    Given a (p x m x n) raster (or array), returns a (p x z) subarray of the
    z pixels that survive filtering. Arguments:
        rast            The input gdal.Dataset or a NumPy array
        filtered_value  The value to remove from the raster array
        indices         If True, return a tuple: (indices, subarray), where
                        indices holds the (row, col) positions of the kept
                        pixels; only honored for multi-band input
    For a multi-band image, a pixel is dropped only when EVERY band equals
    filtered_value. A 1-D array, a 2-D array, or a single-band (1 x m x n)
    array is flattened instead, and a 1-D array of the values that differ
    from filtered_value is returned.
    '''
    rastr = rast.copy() if isinstance(rast, np.ndarray) else rast.ReadAsArray()
    shp = rastr.shape
    ndim = len(shp)

    if ndim == 1:
        # Already raveled
        return rastr[rastr != filtered_value]

    if ndim == 2 or shp[0] == 1:
        # Treat as a single band
        flat = rastr.reshape(1, shp[-2] * shp[-1])
        return flat[flat != filtered_value]

    # Multi-band: one column per pixel
    pixels = rastr.reshape(shp[0], shp[1] * shp[2])
    keep = np.any(pixels != filtered_value, axis=0)
    if not indices:
        return pixels[:, keep]

    grid_shape = (shp[-2], shp[-1])
    return (np.indices(grid_shape)[:, keep.reshape(grid_shape)], pixels[:, keep])
Example 14
Project: unmixing   Author: arthur-e   File: lsma.py    MIT License 5 votes vote down vote up
def normalize_reflectance_within_image(
    rast, band_range=(0, 5), nodata=-9999, scale=100):
    '''
    Following Wu (2004, Remote Sensing of Environment), divides each pixel's
    reflectances by that pixel's mean reflectance *across bands*, as an
    attempt to mitigate within-endmember variability. Arguments:
        rast        A gdal.Dataset or numpy.array instance (bands x rows x cols)
        band_range  Inclusive (first, last) band indices whose per-pixel mean
                    is used as the denominator
        nodata      The NoData value to use (and value to ignore)
        scale       (Optional) Wu's definition scales the normalized reflectance
                    by 100 for some reason; another reasonable value would
                    be 10,000 (approximating scale of Landsat reflectance units);
                    set to None for no scaling.
    Returns a new array with the input's shape; pixels equal to nodata in
    the input are nodata in the output.
    '''
    rastr = rast.copy() if isinstance(rast, np.ndarray) else rast.ReadAsArray()

    shp = rastr.shape
    first, last = band_range
    n_pixels = shp[1] * shp[2]

    # One column per pixel, divided by that pixel's mean over the inclusive
    #   band range
    flat = rastr.reshape((shp[0], n_pixels))
    band_mean = rastr[first:last + 1, ...].mean(axis=0).reshape((1, n_pixels))
    normalized = np.divide(flat, band_mean.repeat(shp[0], axis=0)).reshape(shp)

    if scale is not None:
        normalized = np.multiply(normalized, scale)

    # Restore NoData wherever the input held it
    np.place(normalized, rastr == nodata, nodata)
    return normalized
Example 15
Project: unmixing   Author: arthur-e   File: tests.py    MIT License 5 votes vote down vote up
def test_file_raster_and_array_access(self):
        '''
        Checks that the basic file-reading helpers work: as_array and
        as_raster both return 3-tuples, led by a NumPy array and a
        gdal.Dataset respectively.
        '''
        raster_path = os.path.join(self.test_dir, 'multi3_raster.tiff')
        arr_result = as_array(raster_path)
        rast_result = as_raster(raster_path)
        self.assertTrue(len(arr_result) == len(rast_result) == 3)
        self.assertTrue(isinstance(arr_result[0], np.ndarray))
        self.assertTrue(isinstance(rast_result[0], gdal.Dataset))
Example 16
Project: unmixing   Author: arthur-e   File: visualize.py    MIT License 5 votes vote down vote up
def histogram(arr, valid_range=(0, 1), bins=10, normed=False, cumulative=False,
        file_path='hist.png', title=None):
    '''
    Plots a histogram of all values of an input array over a specified range
    and saves the current Matplotlib figure to file_path. Arguments:
        arr         A gdal.Dataset or NumPy array; all values are pooled
        valid_range The (min, max) range of values to bin
        bins        The number of bins
        normed      True to normalize the histogram to a probability density
                    (passed to Matplotlib as `density`)
        cumulative  True for a cumulative histogram
        file_path   The path of the output image
        title       (Optional) A title for the plot
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(arr, np.ndarray):
        arr = arr.ReadAsArray()

    # Matplotlib removed hist()'s `normed` keyword in 3.1; `density` replaces it
    plt.hist(arr.ravel(), range=valid_range, bins=bins, density=normed,
        cumulative=cumulative)
    if title is not None:
        plt.title(title)

    plt.savefig(file_path)
Example 17
Project: coded   Author: bullocke   File: classify.py    MIT License 5 votes vote down vote up
def create_mask_from_vector(vector_data_path, cols, rows, geo_transform,
                            projection, target_value=1,
                            output_fname='', dataset_format='MEM'):
    """
    Rasterize the first layer of a shapefile (wrapper for gdal.RasterizeLayer).

    :param vector_data_path: Path to a shapefile
    :param cols: Number of columns of the result
    :param rows: Number of rows of the result
    :param geo_transform: GeoTransform as returned by
        gdal.Dataset.GetGeoTransform (coefficients mapping pixel/line (P,L)
        raster space to projection coordinates (Xp,Yp))
    :param projection: Projection definition string (as returned by
        gdal.Dataset.GetProjectionRef)
    :param target_value: Value burned for the layer's features; must be a
        valid gdal.GDT_UInt16 value
    :param output_fname: Output file name when dataset_format writes to disk
        (e.g. GeoTIFF)
    :param dataset_format: The gdal driver name. [default: MEM]
    :return: a single-band gdal.GDT_UInt16 gdal.Dataset
    """
    shp_driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = shp_driver.Open(vector_data_path, 0)  # 0 = open read-only
    if vector_ds is None:
        report_and_exit("File read failed: %s", vector_data_path)
    layer = vector_ds.GetLayer(0)

    raster_driver = gdal.GetDriverByName(dataset_format)
    mask_ds = raster_driver.Create(output_fname, cols, rows, 1, gdal.GDT_UInt16)
    mask_ds.SetGeoTransform(geo_transform)
    mask_ds.SetProjection(projection)
    # Burn target_value into band 1 for the layer's features
    gdal.RasterizeLayer(mask_ds, [1], layer, burn_values=[target_value])
    return mask_ds
Example 18
Project: coded   Author: bullocke   File: classify.py    MIT License 5 votes vote down vote up
def vectors_to_raster(file_paths, rows, cols, geo_transform, projection):
    """
    Rasterize all the given vector files into a single label image.

    Features from the i-th file (counting from 0) are burned with label
    i + 1: the first file's features become 1, the second's 2, and so on.
    Each file is rasterized on its own and the results are added together,
    so where features from different files overlap the pixel holds the sum
    of their labels.

    :param file_paths: Sequence of shapefile paths
    :param rows: Number of rows of the result
    :param cols: Number of columns of the result
    :param geo_transform: GeoTransform as returned by
        gdal.Dataset.GetGeoTransform (coefficients mapping pixel/line (P,L)
        raster space to projection coordinates (Xp,Yp))
    :param projection: Projection definition string (as returned by
        gdal.Dataset.GetProjectionRef)
    :return: (rows x cols) float64 NumPy array of labels, 0 where nothing
        was burned
    """
    labeled_pixels = np.zeros((rows, cols))
    for label, path in enumerate(file_paths, start=1):
        logger.debug("Processing file %s: label (pixel value) %i", path, label)
        mask_ds = create_mask_from_vector(path, cols, rows, geo_transform, projection,
                                          target_value=label)
        burned = mask_ds.GetRasterBand(1).ReadAsArray()
        logger.debug("Labeled pixels: %i", len(burned.nonzero()[0]))
        labeled_pixels += burned
        mask_ds = None  # drop the dataset before rasterizing the next file
    return labeled_pixels
Example 19
Project: coded   Author: bullocke   File: classify.py    MIT License 5 votes vote down vote up
def write_geotiff(fname, data, geo_transform, projection, data_type=gdal.GDT_Byte):
    """
    Create a single-band GeoTIFF file with the given data.

    :param fname: Path of the GeoTIFF file to create
    :param data: 2-D (rows x cols) array written to band 1
    :param geo_transform: GeoTransform as returned by
        gdal.Dataset.GetGeoTransform (coefficients mapping pixel/line (P,L)
        raster space to projection coordinates (Xp,Yp))
    :param projection: Projection definition string (as returned by
        gdal.Dataset.GetProjectionRef)
    :param data_type: GDAL data type of the output band [default: gdal.GDT_Byte]
    """
    driver = gdal.GetDriverByName('GTiff')
    rows, cols = data.shape
    dataset = driver.Create(fname, cols, rows, 1, data_type)
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    band = dataset.GetRasterBand(1)
    band.WriteArray(data)

    metadata = {
        'TIFFTAG_COPYRIGHT': 'CC BY 4.0',
        'TIFFTAG_DOCUMENTNAME': 'classification',
        'TIFFTAG_IMAGEDESCRIPTION': 'Supervised classification.',
        'TIFFTAG_SOFTWARE': 'Python, GDAL, scikit-learn'
    }
    dataset.SetMetadata(metadata)

    # Dereferencing the dataset closes it, which flushes the file to disk
    dataset = None
Example 20
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def create_copy(self,outpath,outformat='GTIFF',options=None):
        '''
        Copy this Dataset (or, for a Band, its parent Dataset) to outpath with
        the given GDAL driver, and return the copy opened as a new Dataset.

        outpath    Path of the output file
        outformat  GDAL driver short name [default: 'GTIFF']
        options    Optional list of driver creation options (None = no options)
        Raises RuntimeError if outpath exists and Env.overwrite is not set.
        '''
        # None instead of a shared mutable [] default
        if options is None:
            options = []
        if os.path.exists(outpath) and not Env.overwrite:
            raise RuntimeError('Output %s exists and overwrite is not set.'%outpath)

        callback = gdal.TermProgress_nocb if Env.progress.enabled else None
        try:                   #Is it a Band? (a Band links to its parent Dataset)
            src = self.dataset._dataset
        except AttributeError: #No, it's a Dataset
            src = self._dataset
        driver = gdal.GetDriverByName(outformat)
        out_ds = driver.CreateCopy(outpath,src,options=options,callback=callback)
        # Dereference the copy so GDAL closes it before it is reopened below
        out_ds = None
        return Dataset(outpath)
Example 21
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __write_vsimem__(self,fn,data):
        '''Write data to the GDAL virtual file fn (e.g. a /vsimem/ path).

        Returns the result of gdal.VSIFCloseL. The handle is closed even if
        the write raises, so it is never leaked.
        '''
        vsifile = gdal.VSIFOpenL(fn,'w')
        if vsifile is None:
            raise IOError('Unable to open %s for writing' % fn)
        try:
            gdal.VSIFWriteL(data, 1, len(data), vsifile)
        finally:
            result = gdal.VSIFCloseL(vsifile)
        return result

    #===========================================================================
    #gdal.Dataset/Band and ndarray attribute calls
    #=========================================================================== 
Example 22
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __init__(self,band,dataset,bandnum=0):
        '''Wrap a single GDAL band that belongs to a parent Dataset wrapper.

        band     The underlying GDAL band object (its GetNoDataValue() is read)
        dataset  The parent Dataset wrapper; its cached sizes are copied and
                 its GetProjectionRef()/GetGeoTransform() are queried
        bandnum  Zero-based index of this band within the parent
        '''
        self._band = band
        self.dataset=dataset #Keep a link to the parent Dataset object

        # Mirror the parent's cached attributes so this Band presents itself
        # like a 1-band dataset
        self._x_size=dataset._x_size
        self._y_size=dataset._y_size
        self._nbands=1
        self._bands=[bandnum]#Keep track of band number, zero based index
        # NOTE(review): DataType / GetBlockSize are not defined in this block;
        # presumably they are delegated to self._band, so they must stay after
        # self._band is assigned - confirm
        self._data_type=self.DataType
        self._srs=dataset.GetProjectionRef()
        self._gt=dataset.GetGeoTransform()
        self._block_size=self.GetBlockSize()
        self._nodata=[band.GetNoDataValue()]
        # Computed last because it may read the attributes set above
        self.extent=self.__get_extent__() 
Example 23
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def get_raster_band(self,*args,**kwargs):
        '''Return this Band itself; any band-index arguments are ignored.

        This lets callers treat a Band and a Dataset interchangeably: asking
        a Band for "its" raster band always yields the Band.
        '''
        band = self
        return band
Example 24
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __init__(self,filepath_or_dataset=None,*args):
        '''Wrap a GDAL dataset and cache its basic properties.

        filepath_or_dataset  A gdal.Dataset, a path or other GDAL dataset name
                             to open, or None when the caller (e.g. a
                             subclass) has already set self._dataset
        *args                Extra arguments passed to gdal.Open
        Note: enables GDAL exceptions globally via gdal.UseExceptions().
        '''
        gdal.UseExceptions()

        fp=filepath_or_dataset

        if type(fp) is gdal.Dataset:
            self._dataset = fp
        elif fp is not None:
            if os.path.exists(fp):
                self._dataset = gdal.Open(os.path.abspath(fp),*args)
            else:
                # Not a local file; passed to GDAL unchanged (e.g. a /vsi path)
                self._dataset = gdal.Open(fp,*args)

        #Issue 8: a positive NS pixel resolution means the image is not north-up,
        #so re-wrap it as a warped in-memory VRT
        self._gt=self.GetGeoTransform()
        if self._gt[5] > 0: #positive NS pixel res.
            tmp_ds = gdal.AutoCreateWarpedVRT(self._dataset)
            # The builtin next() works on Python 2.6+ and Python 3, unlike the
            # Python 2-only .next() method
            tmp_fn = '/vsimem/%s.vrt'%next(tempfile._RandomNameSequence())
            self._dataset = gdal.GetDriverByName('VRT').CreateCopy(tmp_fn,tmp_ds)
            self._gt = self.GetGeoTransform()

        self._x_size=self.RasterXSize
        self._y_size=self.RasterYSize
        self._nbands=self.RasterCount
        self._bands=range(self.RasterCount)
        self._data_type=self.GetRasterBand(1).DataType
        self._srs=self.GetProjectionRef()
        self._block_size=self.GetRasterBand(1).GetBlockSize()
        self._nodata=[b.GetNoDataValue() for b in self]

        self.extent=self.__get_extent__()
Example 25
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __getitem__(self, key):
        ''' Enable "somedataset[bandnum]" syntax with a zero-based bandnum.

        Returns a Band wrapping GDAL band key+1, since GDAL band numbers
        start at 1.
        '''
        gdal_band = self._dataset.GetRasterBand(key + 1)
        return Band(gdal_band, self, key)
Example 26
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __iter__(self):
        ''' Enable "for band in somedataset:" syntax.

        Yields a Band for each band in order, carrying its zero-based index.
        '''
        # range rather than the Python 2-only xrange, so this also runs on Python 3
        for i in range(self.RasterCount):
            yield Band(self.GetRasterBand(i+1),self,i) #GDAL Dataset Band indexing starts at 1
Example 27
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def get_raster_band(self,i=1):
        '''Return GDAL band i (one-based, as in GDAL) wrapped in a Band.

        The Band records the matching zero-based index, i - 1.
        '''
        zero_based = i - 1
        return Band(self._dataset.GetRasterBand(i), self, zero_based)

    #CamelCase synonym 
Example 28
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __del__(self):
        '''Best-effort cleanup: run Dataset.__del__, then gdal.Unlink the file
        named by self._fn (presumably a /vsimem/ path - confirm).

        Errors are swallowed because a finalizer must not raise. Catching
        Exception (not a bare except) lets SystemExit/KeyboardInterrupt through.
        '''
        try:
            Dataset.__del__(self)
        except Exception:
            pass
        try:
            gdal.Unlink(self._fn)
        except Exception:
            pass
Example 29
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __init__(self,filename,outformat='GTIFF',
                 cols=None,rows=None,bands=None,datatype=None,
                 srs='',gt=None,nodata=None,options=None,prototype_ds=None):
        '''Create a new raster file and wrap it as a Dataset.

        filename      Path of the raster to create
        outformat     GDAL driver short name [default: 'GTIFF']
        cols, rows, bands, datatype
                      Raster size, band count and GDAL data type; each falls
                      back to prototype_ds, and is required without one
        srs           WKT projection; falls back to prototype_ds when empty
        gt            GeoTransform; falls back to prototype_ds, otherwise to
                      (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        nodata        Per-band NoData values; falls back to prototype_ds when
                      empty. A value that makes SetNoDataValue raise
                      TypeError is skipped.
        options       Driver creation options
        prototype_ds  Optional Dataset wrapper to copy unspecified properties from
        Raises TypeError when a required property is missing.
        '''
        # None defaults avoid list objects shared between calls
        if gt is None:gt=[]
        if nodata is None:nodata=[]
        if options is None:options=[]

        if prototype_ds is not None:
            if cols is None:cols=prototype_ds._x_size
            if rows is None:rows=prototype_ds._y_size
            if bands is None:bands=prototype_ds._nbands
            if datatype is None:datatype=prototype_ds._data_type
            if not srs:srs=prototype_ds._srs
            if not gt:gt=prototype_ds._gt
            if nodata==[]:nodata=prototype_ds._nodata
        else:
            if cols is None:raise TypeError('Expected "cols" or "prototype_ds", got None')
            if rows is None:raise TypeError('Expected "rows" or "prototype_ds", got None')
            if bands is None:raise TypeError('Expected "bands" or "prototype_ds", got None')
            if datatype is None:raise TypeError('Expected "datatype" or "prototype_ds", got None')
            if not gt:gt=(0.0, 1.0, 0.0, 0.0, 0.0, 1.0)

        self._filename=filename
        # Force GDAL exceptions only around creation, and always restore the
        # caller's setting, even if Create raises
        use_exceptions=gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            self._driver=gdal.GetDriverByName(outformat)
            self._dataset=self._driver.Create (self._filename,cols,rows,bands,datatype,options)
        finally:
            if not use_exceptions:gdal.DontUseExceptions()

        self._dataset.SetGeoTransform(gt)
        self._dataset.SetProjection(srs)
        for i,val in enumerate(nodata[:bands]):
            try:self._dataset.GetRasterBand(i+1).SetNoDataValue(val)
            except TypeError:pass
        Dataset.__init__(self)
Example 30
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __init__(self,dataset_or_band, wkt_srs, snap_ds=None, snap_cellsize=None):
        '''Reproject a Dataset or Band on the fly to wkt_srs via in-memory VRTs.

        dataset_or_band  A Dataset wrapper, or a Band (its parent Dataset is warped)
        wkt_srs          Target projection as WKT
        snap_ds          Optional dataset passed to self._modify_vrt to adjust
                         the warped VRT
        snap_cellsize    Optional cell size passed to self._modify_vrt
        Raises RuntimeError if the warped VRT cannot be created.
        '''
        use_exceptions=gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            # The builtin next() works on Python 2.6+ and Python 3, unlike the
            # Python 2-only .next() method
            self._simple_fn='/vsimem/%s.vrt'%next(tempfile._RandomNameSequence())
            self._warped_fn='/vsimem/%s.vrt'%next(tempfile._RandomNameSequence())

            try:                   #Is it a Band
                orig_ds=dataset_or_band.dataset._dataset
            except AttributeError: #No, it's a Dataset
                orig_ds=dataset_or_band._dataset

            try: #Generate a warped VRT
                warped_ds=gdal.AutoCreateWarpedVRT(orig_ds,orig_ds.GetProjection(),wkt_srs, Env.resampling)
                #AutoCreateWarpedVRT doesn't create a vsimem filename and we need one
                warped_ds=gdal.GetDriverByName('VRT').CreateCopy(self._warped_fn,warped_ds)
            except Exception as e:
                # str(e): Python 3 exceptions have no .message attribute
                raise RuntimeError('Unable to project on the fly. '+str(e))

            #A check that the warped and original geotransforms differ was removed
            #on purpose: a WarpedDataset is also used for resampling, where the input
            #srs equals the output srs and that check would always fail.

            if snap_ds:warped_ds=self._modify_vrt(warped_ds, orig_ds, snap_ds, snap_cellsize)
            self._dataset=self._create_simple_VRT(warped_ds,dataset_or_band)
        finally:
            # Restore the caller's GDAL exception setting, even on failure
            if not use_exceptions:gdal.DontUseExceptions()

        Dataset.__init__(self)
Example 31
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def __del__(self):
        '''Best-effort cleanup: release the wrapped dataset, then unlink the
        temporary /vsimem/ VRT files.

        Errors are deliberately ignored, because an exception escaping
        ``__del__`` cannot be handled by the caller. Only ``Exception``
        subclasses are swallowed (no bare ``except``). Filename attributes
        that were never set (e.g. construction failed early) are skipped.
        '''
        try:Dataset.__del__(self)
        except Exception:pass
        for fn in (getattr(self,'_warped_fn',None), getattr(self,'_simple_fn',None)):
            if fn is None:continue
            try:gdal.Unlink(fn)
            except Exception:pass
Example 32
Project: geobricks_qgis_plugin_trmm   Author: geobricks   File: gdal_dataset.py    GNU General Public License v2.0 5 votes vote down vote up
def buildvrt(self, reference_ds, filepaths, band):
        ''' Create a simple VRT stack (one VRT band per input file).

        Args:
            reference_ds: dataset defining the VRT's size, SRS and geotransform;
                every input is passed through ``reference_ds.apply_environment``.
            filepaths: raster paths to stack, in order.
            band: 0-based index of the band read from each file.

        Returns:
            The VRT XML document as a string.

        Every opened input is appended to ``self._datasets`` so a reference
        to it is kept.
        '''
        from xml.sax.saxutils import escape #SRS text and paths may contain &, <, >

        vrt=[]
        vrt.append('<VRTDataset rasterXSize="%s" rasterYSize="%s">' % (reference_ds.RasterXSize,reference_ds.RasterYSize))
        vrt.append('  <SRS>%s</SRS>' % escape(reference_ds.GetProjection()))
        vrt.append('  <GeoTransform>%s</GeoTransform>' % ', '.join(map(str,reference_ds.GetGeoTransform())))

        for i,f in enumerate(filepaths):
            d=Dataset(f)
            reference_ds,d=reference_ds.apply_environment(d)
            self._datasets.append(d)

            rb=d.GetRasterBand(band+1) #gdal band index start at 1
            nodata=rb.GetNoDataValue()
            path=d.GetDescription()
            #NOTE(review): relative paths are resolved against the VRT's own
            #location; if the VRT is stored under /vsimem/ this may not find
            #the file - confirm.
            rel=not os.path.isabs(path)
            #The VRT band number is this file's 1-based position in the stack;
            #<SourceBand> below selects the band inside the source file.
            vrt.append('  <VRTRasterBand dataType="%s" band="%s">' % (gdal.GetDataTypeName(rb.DataType), i+1))
            vrt.append('    <SimpleSource>')
            vrt.append('      <SourceFilename relativeToVRT="%s">%s</SourceFilename>' % (int(rel),escape(path)))
            vrt.append('      <SourceBand>%s</SourceBand>'%(band+1))
            vrt.append('      <SrcRect xOff="0" yOff="0" xSize="%s" ySize="%s" />' % (d.RasterXSize,d.RasterYSize))
            vrt.append('      <DstRect xOff="0" yOff="0" xSize="%s" ySize="%s" />' % (d.RasterXSize,d.RasterYSize))
            vrt.append('    </SimpleSource>')
            if nodata is not None: # 0 is a valid value
                vrt.append('    <NoDataValue>%s</NoDataValue>' % nodata)
            vrt.append('  </VRTRasterBand>')
        vrt.append('</VRTDataset>')

        return '\n'.join(vrt)
Example 33
Project: pygeotools   Author: dshean   File: warplib.py    MIT License 5 votes vote down vote up
def memwarp(src_ds, res=None, extent=None, t_srs=None, r=None, oudir=None, dst_ndv=0, verbose=True):
    """Warp a single input Dataset into memory with the GDAL MEM driver.

    Thin wrapper around ``warp``: ``res``, ``extent``, ``t_srs`` and ``r`` are
    forwarded positionally, the driver, ``dst_ndv`` and ``verbose`` by keyword.
    ``oudir`` is accepted but not used (memory output has no directory).
    """
    return warp(src_ds, res, extent, t_srs, r,
                driver=iolib.mem_drv,
                dst_ndv=dst_ndv,
                verbose=verbose)

#Use this to warp directly to output file - no need to write to memory then CreateCopy 
Example 34
Project: pygeotools   Author: dshean   File: warplib.py    MIT License 5 votes vote down vote up
def diskwarp(src_ds, res=None, extent=None, t_srs=None, r='cubic', outdir=None, dst_fn=None, dst_ndv=None, verbose=True):
    """Helper function that calls warp for single input Dataset with output to disk (GDAL GeoTiff Driver)

    Parameters
    ----------
    dst_fn : str, optional
        Output filename; defaults to '<first source file without extension>_warp.tif'
    outdir : str, optional
        If given, the output file's basename is placed in this directory
    Other parameters are forwarded to ``warp``.

    Returns
    -------
    gdal.Dataset
        The warped output, reopened from disk

    Raises
    ------
    ValueError
        If dst_fn is not given and src_ds has no backing file (e.g. an in-memory dataset)
    """
    if dst_fn is None:
        src_files = src_ds.GetFileList()
        if not src_files:
            raise ValueError('src_ds has no backing file; please specify dst_fn')
        dst_fn = os.path.splitext(src_files[0])[0]+'_warp.tif'
    if outdir is not None:
        dst_fn = os.path.join(outdir, os.path.basename(dst_fn))
    driver = iolib.gtif_drv
    dst_ds = warp(src_ds, res, extent, t_srs, r, driver, dst_fn, dst_ndv=dst_ndv, verbose=verbose, options=iolib.gdal_opt)
    #Drop the reference so GDAL flushes and closes the file
    dst_ds = None
    #Now reopen ds from disk
    return gdal.Open(dst_fn)
Example 35
Project: pygeotools   Author: dshean   File: iolib.py    MIT License 5 votes vote down vote up
def get_sub_dim(src_ds, scale=None, maxdim=1024):
    """Compute dimensions of subsampled dataset

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset (only RasterXSize and RasterYSize are read)
    scale : int or float, optional
        Scaling factor; if None, chosen so the longer axis fits within maxdim
    maxdim : int, optional
        Maximum dimension along either axis, in pixels (used only when scale is None)

    Returns
    -------
    ns
        Number of samples in subsampled output
    nl
        Number of lines in subsampled output
    scale
        Final scaling factor (dimensions are only reduced when scale > 1)

    Raises
    ------
    ValueError
        If scale, or maxdim when it is used, is not positive
    """
    ns = src_ds.RasterXSize
    nl = src_ds.RasterYSize
    if scale is None:
        maxdim = float(maxdim)
        if maxdim <= 0:
            raise ValueError('maxdim must be positive, got %r' % maxdim)
        scale = max(ns/maxdim, nl/maxdim)
    elif scale <= 0:
        raise ValueError('scale must be positive, got %r' % scale)
    if scale > 1:
        #Never collapse an axis to zero pixels (very elongated rasters)
        ns = max(1, int(round(ns/scale)))
        nl = max(1, int(round(nl/scale)))
    return ns, nl, scale
Example 36
Project: pygdal-json   Author: geospatial-jeff   File: test_vrt.py    GNU General Public License v3.0 5 votes vote down vote up
def test_to_gdal(self):
        # A warped VRT and a translated VRT must both convert to a gdal.Dataset.
        with self.open_vrt(self.warpedvrt) as vrt:
            vrt.warp(dstSRS=3857)
            out_ds = utils.to_gdal(vrt)
            self.assertIsInstance(out_ds, gdal.Dataset)

        with self.open_vrt(self.translatevrt) as vrt:
            vrt.translate(bandList=[3, 2])
            out_ds = utils.to_gdal(vrt)
            self.assertIsInstance(out_ds, gdal.Dataset)
Example 37
Project: pygdal-json   Author: geospatial-jeff   File: test_vrt.py    GNU General Public License v3.0 5 votes vote down vote up
def test_to_file(self):
        # Save a warped and a translated VRT to /vsimem/ and check GDAL can open
        # each result. The in-memory files are unlinked when the test finishes.
        warp_fn = "/vsimem/save_warp.tif"
        with self.open_vrt(self.warpedvrt) as vrt:
            vrt.warp(dstSRS=3857)
            utils.to_file(vrt, warp_fn)
            self.addCleanup(gdal.Unlink, warp_fn)
            ds = gdal.Open(warp_fn)
            self.assertIsInstance(ds, gdal.Dataset)
            ds = None  # close the handle before the file is unlinked

        translate_fn = "/vsimem/save_translate.tif"
        with self.open_vrt(self.translatevrt) as vrt:
            vrt.translate(bandList=[3, 2])
            utils.to_file(vrt, translate_fn)
            self.addCleanup(gdal.Unlink, translate_fn)
            ds = gdal.Open(translate_fn)
            self.assertIsInstance(ds, gdal.Dataset)
            ds = None  # close the handle before the file is unlinked
Example 38
Project: GeoPy   Author: aerler   File: gdal.py    GNU General Public License v3.0 4 votes vote down vote up
def getProjection(var, projection=None):
  ''' Function to infer GDAL parameters from a Variable or Dataset

      var          a Variable or Dataset instance
      projection   optional osr.SpatialReference, or a dict converted with getProjFromDict;
                   if None, 'x'/'y' axes imply a projected grid and 'lon'/'lat' axes
                   imply a WGS84 geographic grid
      Returns (lgdal, projection, isProjected, xlon, ylat), where lgdal is True if var
      can be treated as a map-like GDAL variable.
      Raises TypeError for invalid var/projection types, and AxisError if the
      horizontal attributes are not both Axis instances.
  '''
  if not isinstance(var, (Variable, Dataset)):
    raise TypeError("'var' has to be a Variable or Dataset instance.")
  # infer map axes and projection parameters
  if projection is None:  # can still infer some useful info
    if var.hasAxis('x') and var.hasAxis('y'):
      isProjected = True; xlon = var.x; ylat = var.y
    elif var.hasAxis('lon') and var.hasAxis('lat'):
      isProjected = False; xlon = var.lon; ylat = var.lat
      projection = osr.SpatialReference()
      projection.SetWellKnownGeogCS('WGS84')  # normal lat/lon projection
    else: xlon = None; ylat = None; isProjected = None
  else:
    # figure out projection
    if isinstance(projection, dict): projection = getProjFromDict(projection)
    # assume projection is set
    if not isinstance(projection, osr.SpatialReference):
      raise TypeError('\'projection\' has to be a GDAL SpatialReference object.')
    isProjected = projection.IsProjected()
    # use the horizontal axes matching the projection type; if they are missing,
    # the variable is simply not map-like (lgdal = False below)
    # N.B.: staggered variables are usually only staggered in one dimension, but these variables can not
    #       be treated as a GDAL variable, because their geotransform would be different
    if isProjected:
      if var.hasAxis('x') and var.hasAxis('y'): xlon = var.x; ylat = var.y
      else: xlon = None; ylat = None
    else:
      if var.hasAxis('lon') and var.hasAxis('lat'): xlon = var.lon; ylat = var.lat
      else: xlon = None; ylat = None
  # determine whether the variable is map-like
  if xlon is not None and ylat is not None:
    # check axes: both horizontal attributes have to be Axis instances
    axstr = "'x' and 'y'" if isProjected else "'lon' and 'lat'"
    if not isinstance(xlon, Axis) or not isinstance(ylat, Axis):
      raise AxisError("Error: attributes {:s} have to be axes.".format(axstr))
    # check map axes order for variables
    if isinstance(var, Variable):
      lgdal = ( var.axes[-2] == ylat and var.axes[-1] == xlon ) # we need the (y,x) or (lat,lon) order for the map
    else: lgdal = True # for datasets the order does not matter
  else: lgdal = False
  # return
  return lgdal, projection, isProjected, xlon, ylat
Example 39
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 4 votes vote down vote up
def binary_mask(rast, mask, nodata=-9999, invert=False):
    '''
    Applies an arbitrary, binary mask (data in [0,1]) where pixels with
    a value of 1 are pixels to be masked out. Arguments:
        rast    A gdal.Dataset or a NumPy array
        mask    A gdal.Dataset or a NumPy array
        nodata  The NoData value; defaults to -9999.
        invert  Invert the mask? (swap the meaning of 0 and 1); defaults to False.
    Returns a new array (NumPy inputs are copied, never modified). Raises
    ValueError if the last two (rows, columns) dimensions of raster and mask differ.
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    rastr = rast.copy() if isinstance(rast, np.ndarray) else rast.ReadAsArray()
    maskr = mask.copy() if isinstance(mask, np.ndarray) else mask.ReadAsArray()

    # Compare only the spatial (rows, columns) dimensions; a plain tuple
    #   comparison replaces np.alltrue, which was removed in NumPy 2.0
    if tuple(rastr.shape[-2:]) != tuple(maskr.shape[-2:]):
        raise ValueError('Raster and mask do not have the same shape')

    # Convert Boolean arrays to ones and zeros (np.intp: the former np.int0 alias)
    if maskr.dtype == bool:
        maskr = maskr.astype(np.intp)

    # Transform into a "1-band" array and apply the mask
    if maskr.shape != rastr.shape:
        maskr = maskr.reshape((1, maskr.shape[-2], maskr.shape[-1]))\
            .repeat(rastr.shape[0], axis=0) # Copy the mask across the "bands"

    # TODO Compare to place(), e.g.,
    # np.place(rastr, mask.repeat(rastr.shape[0], axis=0), (nodata,))
    # Mask out areas that match the mask (==1), or the complement if inverted
    if invert:
        rastr[maskr < 1] = nodata
    else:
        rastr[maskr > 0] = nodata

    return rastr
Example 40
Project: unmixing   Author: arthur-e   File: utils.py    MIT License 4 votes vote down vote up
def mask_ledaps_qa(rast, mask, nodata=-9999):
    '''
    Applies a given LEDAPS QA mask to a raster. It's unclear how these
    bit-packed QA values ought to be converted back into 16-bit binary numbers:

    "{0:b}".format(42).zfill(16) # Convert binary to decimal padded left?
    "{0:b}".format(42).ljust(16, '0') # Or convert ... padded right?

    The temporary solution is to use the most common (modal) value as the
    "clear" pixel classification and discard everything else. We'd like to
    just discard pixels above a certain value knowing that everything above
    this threshold has a certain bit-packed QA meaning. For example, mask
    pixels with QA values greater than or equal to 12287:

    int("1000000000000000", 2) == 32768 # Maybe clouds
    int("0010000000000000", 2) == 8192  # Maybe cirrus (NOTE(review): the
                                        # 12287 threshold is not this single bit)

    Similarly, we'd like to discard pixels at or below 4, as these small binary
    numbers correspond to dropped frames, designated fill values, and/or
    terrain occlusion. Arguments:
        rast    A gdal.Dataset or a NumPy array
        mask    A gdal.Dataset or a NumPy array of non-negative integer QA values
        nodata  The NoData value written to masked pixels; defaults to -9999.
    Returns a new array; NumPy inputs are copied, never modified. Raises
    ValueError if the modal QA value is a known error value (<= 4 or >= 12287).
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if isinstance(rast, np.ndarray):
        rastr = rast.copy()
    else:
        rastr = rast.ReadAsArray()

    if isinstance(mask, np.ndarray):
        maskr = mask.copy()
    else:
        maskr = mask.ReadAsArray()

    # Since the QA output is so unreliable (e.g., clouds are called water),
    #   we take the most common QA bit-packed value and assume it refers to
    #   the "okay" pixels
    mode = np.argmax(np.bincount(maskr.ravel()))
    # Explicit check instead of assert, which is stripped under python -O
    if not 4 < mode < 12287:
        raise ValueError('The modal value corresponds to a known error value')
    maskr[np.isnan(maskr)] = 0
    maskr[maskr != mode] = 0
    maskr[maskr == mode] = 1

    # Transform into a "1-band" array and apply the mask
    maskr = maskr.reshape((1, maskr.shape[0], maskr.shape[1]))\
        .repeat(rastr.shape[0], axis=0) # Copy the mask across the "bands"
    rastr[maskr == 0] = nodata
    return rastr
Example 41
Project: pygeotools   Author: dshean   File: warplib.py    MIT License 4 votes vote down vote up
def parse_srs(t_srs, src_ds_list=None):
    """Parse arbitrary input t_srs

    Parameters
    ----------
    t_srs : str or gdal.Dataset or filename or osr.SpatialReference
        Arbitrary input t_srs: 'first', 'last' or 'source' (need src_ds_list),
        an osr.SpatialReference, a gdal.Dataset, a raster filename, an EPSG
        string (e.g. 'EPSG:4326'), a PROJ.4 string, or WKT. None means 'first'.
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first', 'last' or 'source'

    Returns
    -------
    t_srs : osr.SpatialReference() object
        Output spatial reference system; None if both inputs are None or
        t_srs has an unsupported type
    """
    if t_srs is None and src_ds_list is None:
        print("Input t_srs and src_ds_list are both None")
        return None

    if t_srs is None:
        t_srs = 'first'
    if t_srs == 'first' and src_ds_list is not None:
        t_srs = geolib.get_ds_srs(src_ds_list[0])
    elif t_srs == 'last' and src_ds_list is not None:
        t_srs = geolib.get_ds_srs(src_ds_list[-1])
    elif t_srs == 'source' and src_ds_list is not None:
        #Assume ds to be warped is first in ds_list
        t_srs = geolib.get_ds_srs(src_ds_list[0])
    elif isinstance(t_srs, osr.SpatialReference):
        pass
    elif isinstance(t_srs, gdal.Dataset):
        t_srs = geolib.get_ds_srs(t_srs)
    elif isinstance(t_srs, str) and os.path.exists(t_srs):
        t_srs = geolib.get_ds_srs(gdal.Open(t_srs))
    elif isinstance(t_srs, str):
        temp = osr.SpatialReference()
        if '[' in t_srs:
            #WKT always contains brackets. Test it first, because WKT may itself
            #  mention EPSG (AUTHORITY["EPSG",...]) or embed a PROJ.4 string
            temp.ImportFromWkt(t_srs)
        elif 'EPSG' in t_srs.upper():
            epsgcode = int(t_srs.split(':')[-1])
            temp.ImportFromEPSG(epsgcode)
        elif 'proj' in t_srs:
            temp.ImportFromProj4(t_srs)
        else:
            #Assume the user knows what they are doing
            temp.ImportFromWkt(t_srs)
        t_srs = temp
    else:
        t_srs = None
    return t_srs
Example 42
Project: pygeotools   Author: dshean   File: warplib.py    MIT License 4 votes vote down vote up
def parse_res(res, src_ds_list=None, t_srs=None):
    """Parse arbitrary input res

    Parameters
    ----------
    res : str or gdal.Dataset or filename or float
        Arbitrary input res: 'first', 'last', 'min', 'max', 'mean', 'med',
        'common_scale_factor' (need src_ds_list), 'source', a gdal.Dataset,
        a raster filename, or a number
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first' or 'last' or an aggregate option
    t_srs : osr.SpatialReference() object
        Projection for res calculations, optional

    Returns
    -------
    res : float
        Output resolution
        None if source resolution should be preserved
    """
    #Default to using first t_srs for res calculations
    #Assumes src_ds_list is not None
    t_srs = parse_srs(t_srs, src_ds_list)

    #Valid options for res
    res_str_list = ['first', 'last', 'min', 'max', 'mean', 'med', 'common_scale_factor']
    #Position of each aggregate in the get_res_stats output (min, max, mean, med)
    res_stat_idx = {'min': 0, 'max': 1, 'mean': 2, 'med': 3}

    #Compute output resolution in t_srs
    if res in res_str_list and src_ds_list is not None:
        if res == 'first':
            res = geolib.get_res(src_ds_list[0], t_srs=t_srs, square=True)[0]
        elif res == 'last':
            res = geolib.get_res(src_ds_list[-1], t_srs=t_srs, square=True)[0]
        else:
            #Only the aggregate options need stats over every input dataset
            res_stats = geolib.get_res_stats(src_ds_list, t_srs=t_srs)
            if res == 'common_scale_factor':
                #Determine res to upsample min and downsample max by constant factor
                res = np.sqrt(res_stats[1]/res_stats[0]) * res_stats[0]
            else:
                res = res_stats[res_stat_idx[res]]
    elif res == 'source':
        res = None
    elif isinstance(res, gdal.Dataset):
        res = geolib.get_res(res, t_srs=t_srs, square=True)[0]
    elif isinstance(res, str) and os.path.exists(res):
        res = geolib.get_res(gdal.Open(res), t_srs=t_srs, square=True)[0]
    else:
        res = float(res)
    return res
Example 43
Project: pygeotools   Author: dshean   File: warplib.py    MIT License 4 votes vote down vote up
def parse_extent(extent, src_ds_list=None, t_srs=None):
    """Parse arbitrary input extent

    Parameters
    ----------
    extent : str or gdal.Dataset or filename or list of float or None
        Arbitrary input extent. Strings may be 'first', 'last', 'intersection',
        'union' (need src_ds_list), 'source', an existing raster filename, or
        "xmin ymin xmax ymax" separated by whitespace and/or commas.
        None is treated like 'source'.
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first', 'last', 'intersection', or 'union'
    t_srs : osr.SpatialReference() object, optional
        Projection for res calculations

    Returns
    -------
    extent : list of float
        Output extent [xmin, ymin, xmax, ymax]
        None if source extent should be preserved
    """

    #Default to using first t_srs for extent calculations
    if t_srs is not None:
        t_srs = parse_srs(t_srs, src_ds_list)

    #Valid strings
    extent_str_list = ['first', 'last', 'intersection', 'union']

    if extent is None:
        return None

    #Handle strings first: testing a NumPy array against strings
    #  (`extent in extent_str_list`) compares elementwise and can raise
    if isinstance(extent, str):
        if extent in extent_str_list and src_ds_list is not None:
            if len(src_ds_list) == 1 and extent in ('intersection', 'union'):
                extent = None
            elif extent == 'first':
                extent = geolib.ds_geom_extent(src_ds_list[0], t_srs=t_srs)
            elif extent == 'last':
                extent = geolib.ds_geom_extent(src_ds_list[-1], t_srs=t_srs)
            elif extent == 'intersection':
                #By default, compute_intersection takes ref_srs from ref_ds
                extent = geolib.ds_geom_intersection_extent(src_ds_list, t_srs=t_srs)
                if len(src_ds_list) > 1 and extent is None:
                    #NOTE: exits the interpreter rather than raising
                    sys.exit("Input images do not intersect")
            else:
                #'union'; need to clean up union t_srs handling
                extent = geolib.ds_geom_union_extent(src_ds_list, t_srs=t_srs)
        elif extent == 'source':
            extent = None
        elif os.path.exists(extent):
            extent = geolib.ds_geom_extent(gdal.Open(extent), t_srs=t_srs)
        else:
            #split() tolerates repeated whitespace; commas are accepted as separators
            extent = [float(i) for i in extent.replace(',', ' ').split()]
    elif isinstance(extent, (list, tuple, np.ndarray)):
        extent = list(extent)
    elif isinstance(extent, gdal.Dataset):
        extent = geolib.ds_geom_extent(extent, t_srs=t_srs)
    else:
        extent = [float(i) for i in extent.split()]
    return extent
Example 44
Project: pygeotools   Author: dshean   File: iolib.py    MIT License 4 votes vote down vote up
def ds_getma_sub(src_ds, bnum=1, scale=None, maxdim=1024., return_ds=False):
    """Load a subsampled array, rather than full resolution

    This is useful when working with large rasters

    Uses buf_xsize and buf_ysize options from GDAL ReadAsArray method.

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number
    scale : int, optional
        Scaling factor
    maxdim : int, optional
        Maximum dimension along either axis, in pixels
    return_ds : bool, optional
        Also return a single-band GDAL MEM dataset holding the subsampled
        array, with a matching geotransform

    Returns
    -------
    np.ma.array
        Masked array containing raster values
    (np.ma.array, gdal.Dataset)
        If return_ds is True
    """
    b = src_ds.GetRasterBand(bnum)
    b_ndv = get_ndv_b(b)
    ns, nl, scale = get_sub_dim(src_ds, scale, maxdim)
    #The buf_size parameters determine the final array dimensions
    b_array = b.ReadAsArray(buf_xsize=ns, buf_ysize=nl)
    bma = np.ma.masked_values(b_array, b_ndv)
    if not return_ds:
        return bma

    #Use the data type of the band that was read (not always band 1)
    src_ds_sub = gdal.GetDriverByName('MEM').Create('', ns, nl, 1, b.DataType)
    #Scale the geotransform by the factor actually applied along each axis:
    #  `scale` can be < 1 when no subsampling happened, and rounding makes
    #  the per-axis factors differ slightly from it
    x_fac = float(src_ds.RasterXSize) / ns
    y_fac = float(src_ds.RasterYSize) / nl
    gt = list(src_ds.GetGeoTransform())
    gt[1] *= x_fac; gt[4] *= x_fac  # per-column (pixel) terms
    gt[2] *= y_fac; gt[5] *= y_fac  # per-row (line) terms
    src_ds_sub.SetGeoTransform(gt)
    src_ds_sub.SetProjection(src_ds.GetProjection())
    b_sub = src_ds_sub.GetRasterBand(1)
    b_sub.WriteArray(bma)
    b_sub.SetNoDataValue(b_ndv)
    return bma, src_ds_sub

#Note: need to consolidate with warplib.writeout (takes ds, not ma)
#Add option to build overviews when writing GTiff
#Input proj must be WKT