Python rasterio.band() Examples

The following are 18 code examples of rasterio.band(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module rasterio, or try the search function.
Example #1
Source File: rotated_mapping_tools.py    From LSDMappingTools with MIT License 7 votes vote down vote up
def ConvertRaster2LatLong(InputRasterFile,OutputRasterFile):

    """
    Convert a raster to lat long WGS1984 EPSG:4326 coordinates for global plotting.

    Every band of the input raster is reprojected with bilinear resampling and
    written to a new raster that keeps the source metadata (driver, dtype,
    band count, ...) apart from the CRS, transform and size.

    Args:
        InputRasterFile: path of the raster to reproject.
        OutputRasterFile: path of the reprojected raster to create.

    MDH
    """

    # import modules
    import rasterio
    from rasterio.warp import reproject, calculate_default_transform as cdt, Resampling

    # read the source raster
    with rasterio.open(InputRasterFile) as src:
        # get input coordinate system
        Input_CRS = src.crs
        # define the output coordinate system; the plain "EPSG:4326" string
        # replaces the deprecated {'init': 'epsg:4326'} PROJ dict
        Output_CRS = "EPSG:4326"
        # set up the output grid
        dst_transform, dst_width, dst_height = cdt(
            Input_CRS, Output_CRS, src.width, src.height, *src.bounds)
        kwargs = src.meta.copy()
        # rasterio >= 1.0 only understands 'transform'; the old 'affine'
        # key is no longer a valid dataset option, so it is not passed
        kwargs.update({
            'crs': Output_CRS,
            'transform': dst_transform,
            'width': dst_width,
            'height': dst_height
        })

        with rasterio.open(OutputRasterFile, 'w', **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    # src.affine was removed in rasterio 1.0; src.transform
                    # holds the same Affine geotransform
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=Output_CRS,
                    resampling=Resampling.bilinear)
Example #2
Source File: local_io.py    From eo-learn with MIT License 6 votes vote down vote up
def __init__(self, feature, folder=None, *, timestamp_size=None, **kwargs):
        """
        Configure the import task; everything except ``timestamp_size`` is forwarded to the parent class.

        :param feature: EOPatch feature into which data will be imported
        :type feature: (FeatureType, str)
        :param folder: A directory containing image files or a path of an image file
        :type folder: str
        :param timestamp_size: In case data will be imported into a time-dependent feature, this parameter can be
            used to specify the time dimension. If not specified, the time dimension will be the same as the size of
            the FeatureType.TIMESTAMP feature. If FeatureType.TIMESTAMP does not exist, it will be set to 1.
            When converting data into a feature, channels of the given tiff image should be in the order
            T(1)B(1), T(1)B(2), ..., T(1)B(N), T(2)B(1), T(2)B(2), ..., T(2)B(N), ..., ..., T(M)B(N)
            where T and B are the time and band indices.
        :type timestamp_size: int
        :param image_dtype: Type of data of the new feature imported from the tiff image
            (not in the signature; presumably passed via ``**kwargs`` to the parent class - confirm)
        :type image_dtype: numpy.dtype
        :param no_data_value: Value used where the given Geo-Tiff image does not cover the EOPatch
            (not in the signature; presumably passed via ``**kwargs`` to the parent class - confirm)
        :type no_data_value: int or float
        """
        parent_kwargs = kwargs
        super().__init__(feature, folder=folder, **parent_kwargs)

        # only attribute owned by this subclass
        self.timestamp_size = timestamp_size
Example #3
Source File: rasterlayer.py    From Pyspatialml with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, band):
        """Wrap a single rasterio band (``rasterio.band(dataset, bidx)``) as a layer."""

        # let the parent class initialise itself from the band first, giving
        # access to inherited methods/attributes overridden by __init__
        super().__init__(band)

        dataset = band.ds
        first_file = dataset.files[0]

        # RasterLayer-specific attributes
        self.bidx = band.bidx
        self.dtype = band.dtype
        self.nodata = dataset.nodata
        self.file = first_file
        self.ds = dataset
        self.driver = dataset.meta["driver"]
        self.meta = dataset.meta
        self.cmap = "viridis"
        self.norm = None
        self.categorical = False
        self.names = [self._make_name(first_file)]
        # a layer always represents exactly one band
        self.count = 1
        self._close = dataset.close
Example #4
Source File: geo_util.py    From WaterNet with MIT License 5 votes vote down vote up
def reproject_dataset(geotiff_path):
    """Reproject a GeoTIFF into WGS84 (EPSG:4326, i.e. familiar lat/lon pairs).

    The result is written to ``WGS84_DIR`` as ``<input name>_wgs84.tif`` and
    every band is resampled with nearest-neighbour.

    Returns a tuple ``(dataset, path)`` where ``dataset`` is the reprojected
    file opened for reading (the caller is responsible for closing it).
    See https://mapbox.github.io/rasterio/topics/reproject.html"""

    target_crs = 'EPSG:4326'

    with rasterio.open(geotiff_path) as src:
        dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds)

        profile = dict(src.meta)
        profile.update(
            crs=target_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
        )

        out_name = "{}_wgs84.tif".format(get_file_name(geotiff_path))
        out_path = os.path.join(WGS84_DIR, out_name)

        with rasterio.open(out_path, 'w', **profile) as dst:
            for band_idx in range(1, src.count + 1):
                rasterio.warp.reproject(
                    source=rasterio.band(src, band_idx),
                    destination=rasterio.band(dst, band_idx),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=target_crs,
                    resampling=rasterio.warp.Resampling.nearest,
                )

        return rasterio.open(out_path), out_path
Example #5
Source File: local_io.py    From eo-learn with MIT License 5 votes vote down vote up
def _get_bands_subset(self, array):
        """ Return ``array`` restricted to the bands selected by ``self.band_indices``.

        ``None`` keeps every band; a list of ints picks those positions along the last
        axis; a ``(start, end)`` tuple of two ints selects the inclusive range start..end.

        :raises ValueError: if ``band_indices`` has any other form
        """
        indices = self.band_indices
        if indices is None:
            return array

        if isinstance(indices, list):
            if any(not isinstance(idx, int) for idx in indices):
                raise ValueError(f'Invalid format in {indices} list, expected integers')
            return array[..., indices]

        if isinstance(indices, tuple):
            # exactly two elements whose type is exactly int (bools are rejected)
            if len(indices) != 2 or any(type(idx) is not int for idx in indices):
                raise ValueError(f'Invalid format in {indices} tuple, expected integers')
            start, end = indices
            return array[..., start:end + 1]

        raise ValueError(f'Invalid format in {indices}, expected tuple or list')
Example #6
Source File: statistics_rasterInpolygon.py    From python-urbanPlanning with MIT License 5 votes vote down vote up
def reprojectedRaster(rasterFn,ref_vectorFn,dst_raster_projected):
    """Reproject the raster ``rasterFn`` into the CRS of the vector file ``ref_vectorFn``.

    The result is written to ``dst_raster_projected``. The output dataset is
    created with dtype uint8 regardless of the source dtype, and every band is
    resampled with nearest-neighbour. The target CRS and the elapsed wall-clock
    time are printed.
    """
    dst_crs = gpd.read_file(ref_vectorFn).crs
    print(dst_crs)  # e.g. {'init': 'epsg:4326'}
    started = datetime.datetime.now()

    with rasterio.open(rasterFn) as src:
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds)
        out_meta = src.meta.copy()
        out_meta.update(
            crs=dst_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
            dtype=rasterio.uint8,  # rasterio.float32 would keep fractional values
        )

        with rasterio.open(dst_raster_projected, 'w', **out_meta) as dst:
            for band_no in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_no),
                    destination=rasterio.band(dst, band_no),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    finished = datetime.datetime.now()
    print("reprojected time span:", finished - started)
 
# Compute raster cell statistics for each polygon
Example #7
Source File: distanceWeightCalculation_raster2Polygon.py    From python-urbanPlanning with MIT License 5 votes vote down vote up
def reprojectedRaster(rasterFn,ref_vectorFn):
    """Reproject the raster ``rasterFn`` into the CRS of the vector file ``ref_vectorFn``.

    The result is written to ``svf_dstRasterProjected_b.tif`` inside the
    module-level directory ``dataFp_1``. The output dataset is created with
    dtype float32 and every band is resampled with nearest-neighbour. The
    target CRS and the elapsed wall-clock time are printed.
    """
    dst_crs = gpd.read_file(ref_vectorFn).crs
    print(dst_crs)  # e.g. {'init': 'epsg:4326'}
    dst_raster_projected = os.path.join(dataFp_1, r"svf_dstRasterProjected_b.tif")
    started = datetime.datetime.now()

    with rasterio.open(rasterFn) as src:
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds)
        out_meta = src.meta.copy()
        out_meta.update(
            crs=dst_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
            dtype=rasterio.float32,
        )

        with rasterio.open(dst_raster_projected, 'w', **out_meta) as dst:
            for band_no in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_no),
                    destination=rasterio.band(dst, band_no),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest,
                )

    finished = datetime.datetime.now()
    print("reprojected time span:", finished - started)
 

# Compute raster cell statistics for each polygon
Example #8
Source File: preprocessing.py    From Pyspatialml with GNU General Public License v3.0 5 votes vote down vote up
def _grid_distance(shape, rows, cols):
    """Generate buffer distances to x,y coordinates.

    Parameters
    ----------
    shape : tuple
        shape of numpy array (rows, cols) to create buffer distances within.

    rows : 1d array-like of int
        array of row indexes.

    cols : 1d array-like of int
        array of column indexes (paired element-wise with ``rows``).

    Returns
    -------
    ndarray
        3d float32 numpy array of euclidean grid distances to each x,y
        coordinate pair, ordered [band, row, col] with one band per pair.
    """
    n_rows, n_cols = shape[0], shape[1]

    # allocate directly in band, row, column order so no transpose is needed
    grids_buffers = np.zeros((len(rows), n_rows, n_cols), dtype=np.float32)

    for i, (y, x) in enumerate(zip(rows, cols)):
        # True everywhere except the pick location, so the Euclidean distance
        # transform yields each cell's distance to that single point
        background = np.ones((n_rows, n_cols), dtype=bool)
        background[y, x] = False
        # scipy.ndimage.morphology is a deprecated alias namespace; the
        # function lives directly in scipy.ndimage
        grids_buffers[i] = ndimage.distance_transform_edt(background)

    return grids_buffers
Example #9
Source File: rasterlayer.py    From Pyspatialml with GNU General Public License v3.0 5 votes vote down vote up
def _extract_by_indices(self, rows, cols):
        """Spatial query of Raster object (by-band).

        Returns a float32 masked array of shape ``(len(rows), self.count)`` whose
        first column holds this layer's values at the (row, col) pixel pairs;
        masked source pixels stay masked.
        """
        layer_values = self.read(masked=True)
        picked = layer_values[rows, cols]

        X = np.ma.zeros((len(rows), self.count), dtype="float32")
        X[:, 0] = picked
        return X
Example #10
Source File: rasterlayer.py    From Pyspatialml with GNU General Public License v3.0 5 votes vote down vote up
def _write(self, arr, file_path=None, driver="GTiff", dtype=None, nodata=None):
        """Write ``arr`` as a single-band raster and return it as a new RasterLayer.

        Used internally to persist processed results, usually to a tempfile.
        ``dtype`` defaults to this layer's dtype, and ``nodata`` to the value
        returned by ``_get_nodata`` for that dtype. Pixels equal to this layer's
        nodata value are rewritten with the new nodata value.
        """
        # fall back to a temporary file when file_path is None
        file_path, tfile = _file_path_tempfile(file_path)

        dtype = self.dtype if dtype is None else dtype
        nodata = _get_nodata(dtype) if nodata is None else nodata

        meta = self.ds.meta
        meta.update(driver=driver, nodata=nodata, dtype=dtype)

        # swap this layer's nodata pixels for the requested nodata value
        out = np.ma.masked_equal(arr, self.nodata).filled(fill_value=nodata)

        with rasterio.open(file_path, mode="w", **meta) as dst:
            dst.write(out.astype(dtype), 1)

        # the dataset stays open; the returned layer owns it
        dataset = rasterio.open(file_path)
        layer = pyspatialml.RasterLayer(rasterio.band(dataset, 1))

        # with a temp file, closing the layer closes the temp file instead
        if tfile is not None:
            layer.close = tfile.close

        return layer
Example #11
Source File: raster.py    From Pyspatialml with GNU General Public License v3.0 4 votes vote down vote up
def _probfun(img, estimator):
        """Class probabilities function.

        Parameters
        ----------
        img : tuple (window, numpy.ndarray or pandas.DataFrame)
            A window object, and either a 3d (optionally masked) ndarray of raster
            data with the dimensions in order of (band, rows, columns), or a
            DataFrame with one row per pixel (row-major within the window) and one
            column per band.

        estimator : estimator object implementing 'fit'
            The object to use to fit the data; must provide ``predict_proba``.

        Returns
        -------
        numpy.ma.MaskedArray
            Multi band raster as a 3d masked array containing the probabilities
            associated with each class. Dimensions are in the order of
            (class, row, column). Pixels where any input band is missing are masked.
        """
        window, img = img

        if not isinstance(img, pd.DataFrame):
            # reshape each image block matrix into a 2D matrix:
            # first reorder into rows, cols, bands (transpose),
            # then reshape into 2D array (rows=sample_n, cols=band_values)
            n_features, rows, cols = img.shape[0], img.shape[1], img.shape[2]
            # getmaskarray always returns a full boolean array, so this also works
            # for plain ndarrays and masked arrays whose mask is `nomask`, where
            # `img.mask.any(axis=0)` would fail
            mask2d = np.ma.getmaskarray(img).any(axis=0)
            n_samples = rows * cols
            flat_pixels = img.transpose(1, 2, 0).reshape((n_samples, n_features))
            # placeholder for missing cells; they are re-masked in the output
            flat_pixels = np.ma.filled(flat_pixels, 0)

        else:
            flat_pixels = img
            mask2d = pd.isna(flat_pixels).values
            mask2d = mask2d.reshape((window.height, window.width, flat_pixels.shape[1]))
            mask2d = mask2d.any(axis=2)
            flat_pixels = flat_pixels.fillna(0)
            flat_pixels = flat_pixels.values

        # predict probabilities, one row per pixel and one column per class
        result_proba = estimator.predict_proba(flat_pixels)

        # reshape class probabilities back to a 3D image [rows, cols, iclass]
        result_proba = result_proba.reshape((window.height, window.width, result_proba.shape[1]))

        # reorder into rasterio format [band, row, col]
        result_proba = result_proba.transpose(2, 0, 1)

        # repeat the pixel mask for every class band
        mask3d = np.repeat(
            a=mask2d[np.newaxis, :, :], repeats=result_proba.shape[0], axis=0
        )

        # convert proba to masked array
        return np.ma.masked_array(result_proba, mask=mask3d, fill_value=np.nan)
Example #12
Source File: raster.py    From Pyspatialml with GNU General Public License v3.0 4 votes vote down vote up
def _predfun_multioutput(img, estimator):
        """Multi-target prediction function.

        Parameters
        ----------
        img : tuple (window, numpy.ndarray or pandas.DataFrame)
            A window object, and either a 3d (optionally masked) ndarray of raster
            data with the dimensions in order of (band, rows, columns), or a
            DataFrame with one row per pixel (row-major within the window) and one
            column per band.

        estimator : estimator object implementing 'fit'
            The object to use to fit the data; its ``predict`` must return a 2d
            array (one column per target).

        Returns
        -------
        numpy.ma.MaskedArray
            3d masked array representing the multi-target prediction result with
            the dimensions in the order of (target, row, column). Pixels where any
            input band is missing are masked.
        """
        window, img = img

        if not isinstance(img, pd.DataFrame):
            # reshape each image block matrix into a 2D matrix:
            # first reorder into rows, cols, bands (transpose),
            # then reshape into 2D array (rows=sample_n, cols=band_values)
            n_features, rows, cols = img.shape[0], img.shape[1], img.shape[2]
            # getmaskarray always returns a full boolean array, so this also works
            # for plain ndarrays and masked arrays whose mask is `nomask`, where
            # `img.mask.any(axis=0)` would fail
            mask2d = np.ma.getmaskarray(img).any(axis=0)
            n_samples = rows * cols
            flat_pixels = img.transpose(1, 2, 0).reshape((n_samples, n_features))
            # placeholder for missing cells; they are re-masked in the output
            flat_pixels = np.ma.filled(flat_pixels, 0)

        else:
            flat_pixels = img
            mask2d = pd.isna(flat_pixels).values
            mask2d = mask2d.reshape((window.height, window.width, flat_pixels.shape[1]))
            mask2d = mask2d.any(axis=2)
            flat_pixels = flat_pixels.fillna(0)
            flat_pixels = flat_pixels.values

        # predict targets, one row per pixel and one column per target
        result = estimator.predict(flat_pixels)

        # reshape predictions back to a 3D image [rows, cols, target]
        result = result.reshape((window.height, window.width, result.shape[1]))

        # reorder into rasterio format [band, row, col]
        result = result.transpose(2, 0, 1)

        # repeat the pixel mask for every target band
        mask3d = np.repeat(a=mask2d[np.newaxis, :, :], repeats=result.shape[0], axis=0)

        # convert predictions to masked array
        return np.ma.masked_array(result, mask=mask3d, fill_value=np.nan)
Example #13
Source File: mbtiler.py    From rio-rgbify with MIT License 4 votes vote down vote up
def _tile_worker(tile):
    """
    Warp the shared source raster onto one 512 x 512 web-mercator tile and
    encode the result as RGB.

    Relies on module-level globals: ``src`` (an open rasterio dataset whose
    first band is warped) and ``global_args`` (a dict providing ``base_val``,
    ``interval``, ``writer_func`` and ``kwargs``).

    Parameters
    -----------
    tile: list
        [x, y, z] indices of tile

    Returns
    --------
    tile, buffer
        tuple with the input tile, and whatever ``writer_func`` returns for the
        RGB-encoded data (the encoded bytes in the format it produces)
    """
    x, y, z = tile

    # upper-left of the tile below -> (west, south);
    # upper-left of the tile to the right -> (east, north)
    west, south = mercantile.xy(*mercantile.ul(x, y + 1, z))
    east, north = mercantile.xy(*mercantile.ul(x + 1, y, z))
    toaffine = transform.from_bounds(west, south, east, north, 512, 512)

    warped = np.empty((512, 512), dtype=src.meta["dtype"])
    reproject(
        rasterio.band(src, 1),
        warped,
        dst_transform=toaffine,
        dst_crs="epsg:3857",
        resampling=Resampling.bilinear,
    )

    rgb = data_to_rgb(warped, global_args["base_val"], global_args["interval"])
    writer = global_args["writer_func"]
    return tile, writer(rgb, global_args["kwargs"].copy(), toaffine)
Example #14
Source File: __init__.py    From rio-mbtiles with MIT License 4 votes vote down vote up
def process_tile(tile):
    """Process a single MBTiles tile

    Reprojects the module-level source dataset ``src`` onto the extent of
    ``tile`` inside an in-memory dataset and returns that dataset's bytes.
    Reads the module-level globals ``base_kwds`` (creation options for the
    in-memory dataset, plus optional ``src_nodata``/``dst_nodata``),
    ``resampling`` and ``src``.

    Parameters
    ----------
    tile : mercantile.Tile

    Returns
    -------

    tile : mercantile.Tile
        The input tile.
    bytes : bytearray
        Image bytes corresponding to the tile, or None when the source's
        band-1 mask shows no valid data in the tile's (buffered) window.

    """
    global base_kwds, resampling, src

    # Get the bounds of the tile.
    ulx, uly = mercantile.xy(
        *mercantile.ul(tile.x, tile.y, tile.z))
    lrx, lry = mercantile.xy(
        *mercantile.ul(tile.x + 1, tile.y + 1, tile.z))

    # copy so the shared base_kwds is not mutated by the pops below
    kwds = base_kwds.copy()
    kwds['transform'] = transform_from_bounds(ulx, lry, lrx, uly,
                                              kwds['width'], kwds['height'])
    # nodata values are removed from kwds (so they are not passed to
    # memfile.open) and used only as reproject() arguments
    src_nodata = kwds.pop('src_nodata', None)
    dst_nodata = kwds.pop('dst_nodata', None)

    # NOTE(review): this changes the process-wide warnings filter, not just
    # the filter for this call
    warnings.simplefilter('ignore')

    with MemoryFile() as memfile:

        with memfile.open(**kwds) as tmp:

            # determine window of source raster corresponding to the tile
            # image, with small buffer at edges
            try:
                west, south, east, north = transform_bounds(TILES_CRS, src.crs, ulx, lry, lrx, uly)
                tile_window = window_from_bounds(west, south, east, north, transform=src.transform)
                # pad the window by one pixel on every side
                adjusted_tile_window = Window(
                    tile_window.col_off - 1, tile_window.row_off - 1,
                    tile_window.width + 2, tile_window.height + 2)
                tile_window = adjusted_tile_window.round_offsets().round_shape()

                # if no data in window, skip processing the tile
                if not src.read_masks(1, window=tile_window).any():
                    return tile, None

            except ValueError:
                # the emptiness check could not be done (presumably the tile
                # window could not be computed - confirm); fall through and
                # render the tile anyway
                log.info("Tile %r will not be skipped, even if empty. This is harmless.", tile)

            # every band of tmp is filled from the same band index of src
            reproject(rasterio.band(src, tmp.indexes),
                      rasterio.band(tmp, tmp.indexes),
                      src_nodata=src_nodata,
                      dst_nodata=dst_nodata,
                      num_threads=1,
                      resampling=resampling)

        # read only after tmp has been closed (presumably so all data is
        # flushed to the in-memory file - confirm)
        return tile, memfile.read()
Example #15
Source File: preprocessing.py    From Pyspatialml with GNU General Public License v3.0 4 votes vote down vote up
def rotated_coordinates(layer, n_angles=8, file_path=None, driver='GTiff'):
    """Generate 2d arrays with n_angles rotated coordinates.

    Band ``k`` holds ``col + tan(angle_k) * row`` for every pixel, where the
    angles are ``n_angles`` values evenly spaced in [0, 180) degrees.

    Parameters
    ----------
    layer : pyspatialml.RasterLayer, or rasterio.DatasetReader
        RasterLayer to use as a template.

    n_angles : int, optional. Default is 8
        Number of angles to rotate coordinate system by.

    file_path : str, optional. Default is None
        File path to save to the resulting Raster object. If not supplied then the
        raster is saved to a temporary file.

    driver : str, optional. Default is 'GTiff'
        GDAL driver to use to save raster.

    Returns
    -------
    pyspatialml.Raster
        Raster with one band per angle, named 'angle_1' ... 'angle_<n_angles>'.
    """
    n_rows, n_cols = layer.shape[0], layer.shape[1]
    x_range = np.arange(start=0, stop=n_cols, step=1)
    y_range = np.arange(start=0, stop=n_rows, step=1, dtype=np.float32)

    # 2d column/row index grids, given a trailing axis so that they broadcast
    # against the vector of rotation angles -> (rows, cols, n_angles)
    X_grid, Y_grid = np.meshgrid(x_range, y_range)
    angles = np.deg2rad(np.linspace(0, 180, n_angles, endpoint=False))
    rotated = X_grid[..., np.newaxis] + np.tan(angles) * Y_grid[..., np.newaxis]

    # reorder to band, row, col order
    rotated = rotated.transpose((2, 0, 1))

    # create new stack
    file_path, tfile = _file_path_tempfile(file_path)

    meta = deepcopy(layer.meta)
    meta.update(driver=driver, count=n_angles, dtype=rotated.dtype)
    with rasterio.open(file_path, 'w', **meta) as dst:
        dst.write(rotated)

    new_raster = Raster(file_path)
    new_names = ['angle_' + str(k) for k in range(1, n_angles + 1)]
    new_raster.rename(dict(zip(new_raster.names, new_names)))

    # closing any band of a temp-file raster closes the temp file
    if tfile is not None:
        for band_layer in new_raster.iloc:
            band_layer.close = tfile.close

    return new_raster
Example #16
Source File: geo.py    From solaris with Apache License 2.0 4 votes vote down vote up
def _reproject(input_data, input_type, input_crs, target_crs, dest_path,
               resampling_method='bicubic'):
    """Reproject vector or raster ``input_data`` from ``input_crs`` to ``target_crs``.

    Parameters
    ----------
    input_data : vector object with ``to_crs``, rasterio.DatasetReader or gdal.Dataset
        Data to reproject (the vector case presumably a geopandas GeoDataFrame).
    input_type : {'vector', 'raster'}
        Kind of ``input_data``.
    input_crs, target_crs
        Source and destination CRS, normalised with ``_check_crs``.
    dest_path : str or None
        Where to write the result. Vectors are written as GeoJSON. For rasterio
        input, the reprojected raster is written there and ``input_data`` is
        closed. GDAL input requires a path.
    resampling_method : str
        Name of a ``Resampling`` member (used for rasterio input only).

    Returns
    -------
    The reprojected vector object; for rasterio input, an opened rasterio
    dataset (``dest_path`` given) or a float ``(height, width, band)`` numpy
    array; for GDAL input, a ``gdal.Dataset``.

    Raises
    ------
    ValueError
        If ``input_type`` or the raster object's type is unsupported, or a GDAL
        dataset is given without ``dest_path``.
    """
    input_crs = _check_crs(input_crs)
    target_crs = _check_crs(target_crs)
    if input_type == 'vector':
        output = input_data.to_crs(target_crs)
        if dest_path is not None:
            output.to_file(dest_path, driver='GeoJSON')

    elif input_type == 'raster':

        if isinstance(input_data, rasterio.DatasetReader):
            # WKT form of the target CRS, shared by the transform calculation,
            # the output profile and the file-based reprojection
            target_wkt = target_crs.to_wkt("WKT1_GDAL")
            transform, width, height = calculate_default_transform(
                input_crs.to_wkt("WKT1_GDAL"), target_wkt,
                input_data.width, input_data.height, *input_data.bounds
            )
            kwargs = input_data.meta.copy()
            kwargs.update({'crs': target_wkt,
                           'transform': transform,
                           'width': width,
                           'height': height})

            if dest_path is not None:
                with rasterio.open(dest_path, 'w', **kwargs) as dst:
                    for band_idx in range(1, input_data.count + 1):
                        rasterio.warp.reproject(
                            source=rasterio.band(input_data, band_idx),
                            destination=rasterio.band(dst, band_idx),
                            src_transform=input_data.transform,
                            src_crs=input_data.crs,
                            dst_transform=transform,
                            dst_crs=target_wkt,
                            resampling=getattr(Resampling, resampling_method)
                        )
                output = rasterio.open(dest_path)
                # NOTE: the caller's dataset is closed once the copy is written
                input_data.close()

            else:
                # in-memory result laid out as (row, col, band)
                output = np.zeros(shape=(height, width, input_data.count))
                for band_idx in range(1, input_data.count + 1):
                    rasterio.warp.reproject(
                        source=rasterio.band(input_data, band_idx),
                        destination=output[:, :, band_idx-1],
                        src_transform=input_data.transform,
                        src_crs=input_data.crs,
                        dst_transform=transform,
                        dst_crs=target_crs,
                        resampling=getattr(Resampling, resampling_method)
                    )

        elif isinstance(input_data, gdal.Dataset):
            if dest_path is not None:
                gdal.Warp(dest_path, input_data,
                          dstSRS='EPSG:' + str(target_crs.to_epsg()))
                output = gdal.Open(dest_path)
            else:
                raise ValueError('An output path must be provided for '
                                 'reprojecting GDAL datasets.')
        else:
            # previously fell through and failed with UnboundLocalError
            raise ValueError(
                'Unsupported raster object of type {}; expected a '
                'rasterio.DatasetReader or gdal.Dataset.'.format(type(input_data)))
    else:
        # previously fell through and failed with UnboundLocalError
        raise ValueError(
            "input_type must be 'vector' or 'raster', got {!r}.".format(input_type))
    return output
Example #17
Source File: local_io.py    From eo-learn with MIT License 4 votes vote down vote up
def execute(self, eopatch, *, filename=None):
        """ Execute method

        Writes the selected feature of ``eopatch`` to a GeoTIFF. When ``self.crs`` is
        set to a CRS different from the EOPatch's, each channel is reprojected with
        nearest-neighbour resampling; otherwise the array is written as is.

        :param eopatch: input EOPatch
        :type eopatch: EOPatch
        :param filename: filename of tiff file or None if entire path has already been specified in `folder` parameter
            of task initialization.
        :type filename: str or None
        :return: Unchanged input EOPatch
        :rtype: EOPatch
        :raises ValueError: if the feature is missing and ``self.fail_on_missing`` is set
        """
        try:
            feature = next(self.feature(eopatch))
        except ValueError as error:
            LOGGER.warning(error)

            if self.fail_on_missing:
                # keep the original exception as the explicit cause
                raise ValueError(error) from error
            return eopatch

        image_array = self._prepare_image_array(eopatch, feature)

        channel_count, height, width = image_array.shape

        src_crs = {'init': eopatch.bbox.crs.ogc_string()}
        src_transform = rasterio.transform.from_bounds(*eopatch.bbox, width=width, height=height)

        if self.crs:
            dst_crs = {'init': self.crs.ogc_string()}
            dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform(
                src_crs, dst_crs, width, height, *eopatch.bbox
            )
        else:
            dst_crs = src_crs
            dst_transform = src_transform
            dst_width, dst_height = width, height

        with rasterio.open(self._get_file_path(filename, create_dir=True), 'w', driver='GTiff',
                           width=dst_width, height=dst_height,
                           count=channel_count,
                           dtype=image_array.dtype, nodata=self.no_data_value,
                           transform=dst_transform, crs=dst_crs) as dst:

            if dst_crs == src_crs:
                # same CRS: grid is unchanged, write all channels at once
                dst.write(image_array)
            else:
                for idx in range(channel_count):
                    rasterio.warp.reproject(
                        source=image_array[idx, ...],
                        destination=rasterio.band(dst, idx + 1),
                        src_transform=src_transform,
                        src_crs=src_crs,
                        dst_transform=dst_transform,
                        dst_crs=dst_crs,
                        resampling=rasterio.warp.Resampling.nearest
                    )

        return eopatch
Example #18
Source File: gis.py    From oggm with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def rasterio_to_gdir(gdir, input_file, output_file_name,
                     resampling='cubic'):
    """Reprojects a file that rasterio can read into the glacier directory.

    Every band of ``input_file`` is resampled onto the grid (CRS, transform,
    width, height) of the glacier directory's ``dem`` file and written to the
    file registered as ``output_file_name``.

    Parameters
    ----------
    gdir : :py:class:`oggm.GlacierDirectory`
        the glacier directory
    input_file : str
        path to the file to reproject
    output_file_name : str
        name of the output file (must be in cfg.BASENAMES)
    resampling : str
        'nearest', 'bilinear', 'cubic', 'cubic_spline', or one of
        https://rasterio.readthedocs.io/en/latest/topics/resampling.html

    Raises
    ------
    InvalidWorkflowError
        if the glacier directory has no ``dem`` file to use as template
    """

    output_file = gdir.get_filepath(output_file_name)
    assert '.tif' in output_file, 'output_file should end with .tif'

    if not gdir.has_file('dem'):
        raise InvalidWorkflowError('Need a dem.tif file to reproject to')

    with rasterio.open(input_file) as src:

        kwargs = src.meta.copy()
        # dtype of band 1 taken from the metadata, instead of reading the
        # whole band into memory only to inspect its dtype
        band_dtype = src.dtypes[0]

        with rasterio.open(gdir.get_filepath('dem')) as tpl:

            # output takes the DEM's grid but keeps the source's other metadata
            kwargs.update({
                'crs': tpl.crs,
                'transform': tpl.transform,
                'width': tpl.width,
                'height': tpl.height
            })

            with rasterio.open(output_file, 'w', **kwargs) as dst:
                for i in range(1, src.count + 1):

                    dest = np.zeros(shape=(tpl.height, tpl.width),
                                    dtype=band_dtype)

                    reproject(
                        source=rasterio.band(src, i),
                        destination=dest,
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=tpl.transform,
                        dst_crs=tpl.crs,
                        resampling=getattr(Resampling, resampling)
                    )

                    dst.write(dest, indexes=i)