# rasterio is an optional dependency: when it cannot be imported the module
# still loads, and has_rasterio records that GeoTIFF output is unavailable.
# Only ImportError is caught so that Ctrl-C / SystemExit and unrelated errors
# raised while importing are not silently hidden.
try:
    import rasterio
    has_rasterio = True
except ImportError:
    has_rasterio = False

from functools import partial
import os

import dask
from dask.array import store

import numpy as np

# `dask.threaded` is a submodule: `import dask` alone does not guarantee it is
# loaded, so import it explicitly instead of relying on a transitive import.
import dask.threaded

# Worker-thread count for dask's threaded scheduler. It can be overridden with the
# GBDX_THREADS environment variable, which is read once, at import time.
threads = int(os.environ.get('GBDX_THREADS', 64))
# Scheduler callable handed to `compute(scheduler=...)`, pre-bound to `threads`.
threaded_get = partial(dask.threaded.get, num_workers=threads)

class rio_writer(object):
    """Adapter that lets ``dask.array.store`` write chunks into a rasterio dataset.

    ``store`` assigns every computed chunk with ``target[location] = chunk``.
    This class turns that assignment into a windowed ``dst.write`` call.
    """

    def __init__(self, dst):
        # dst: an open, writable rasterio dataset (anything exposing a
        # ``write(arr, indexes=..., window=...)`` method works).
        self.dst = dst

    def __setitem__(self, location, chunk):
        """Write ``chunk`` into the region of ``dst`` addressed by ``location``.

        Args:
            location (tuple of slice): 0-based ``(band, row, col)`` slices with
                integer ``start``/``stop``.
            chunk (array-like): 3-D block of shape ``(bands, rows, cols)``
                matching ``location``.
        """
        band_slice, row_slice, col_slice = location
        # rasterio band indexes are 1-based. Honouring the band slice (instead
        # of implicitly writing every band) keeps this correct when the source
        # array is also chunked along the band axis.
        indexes = list(range(band_slice.start + 1, band_slice.stop + 1))
        window = ((row_slice.start, row_slice.stop),
                  (col_slice.start, col_slice.stop))
        self.dst.write(chunk, indexes=indexes, window=window)

def to_geotiff(arr, path='./output.tif', proj=None, spec=None, bands=None, **kwargs):
    ''' Write out a geotiff file of the image

    The array is evaluated with dask's threaded scheduler (``threaded_get``).
    Each computed chunk is written into the open file through ``rio_writer``.

    Args:
        arr: 3-D array of shape (bands, rows, cols) that ``dask.array.store``
            accepts. Optional attributes are used when present:
            ``rda.metadata["image"]`` (tile sizes) and ``affine``
            (geotransform). ``options`` and ``_rgb_bands`` are also required
            when ``spec='rgb'``.
        path (str): path to write the geotiff file to, default is ./output.tif
        proj (str): EPSG string written as the file's CRS (passed to rasterio
            as ``{'init': proj}``). Pixels are written as-is; no reprojection
            is done here.
        spec (str): if set to 'rgb', write out color-balanced 8-bit RGB tif
        bands (list): list of bands to export. If spec='rgb' will default to RGB bands
        **kwargs:
            transform: geotransform to write instead of ``arr.affine``.
            tiled (bool): if true, write a tiled GeoTIFF whose block size is
                the image tile size (or ``chunk_size``).
            chunk_size (int): block size used when the image tile size is not
                available, default 256.

    Returns:
        str: path the geotiff was written to

    Raises:
        AssertionError: if rasterio is not installed, or if spec='rgb' is
            requested for an image not created with ``dra=True``.
    '''
    # Explicit raises instead of `assert`, so the checks survive `python -O`.
    # AssertionError is kept for callers that may already catch it.
    if not has_rasterio:
        raise AssertionError("To create geotiff images please install rasterio")

    # Block size for tiled output: prefer the source image's native tile size.
    try:
        img_md = arr.rda.metadata["image"]
        x_size = img_md["tileXSize"]
        y_size = img_md["tileYSize"]
    except (AttributeError, KeyError):
        x_size = kwargs.get("chunk_size", 256)
        y_size = kwargs.get("chunk_size", 256)

    if 'transform' in kwargs:
        tfm = kwargs['transform']
    else:
        # Best effort: if no geotransform can be derived, write the file
        # ungeoreferenced instead of aborting. `Exception` (not a bare
        # except) so KeyboardInterrupt / SystemExit still propagate.
        try:
            tfm = arr.affine
        except Exception:
            tfm = None

    # NOTE(review): int8 data is declared as uint8, presumably because GDAL's
    # GTiff driver historically had no signed 8-bit type. The chunks stay
    # int8, so confirm how negative values are meant to be stored.
    dtype = arr.dtype.name if arr.dtype.name != 'int8' else 'uint8'

    if spec is not None and spec.lower() == 'rgb':
        if not arr.options.get('dra'):
            raise AssertionError('To write RGB geotiffs, create your image option with `dra=True`')
        if bands is None:
            bands = arr._rgb_bands
        arr = arr[bands, ...].astype(np.uint8)
        dtype = 'uint8'
    elif bands is not None:
        arr = arr[bands, ...]

    meta = {
        'width': arr.shape[2],
        'height': arr.shape[1],
        'count': arr.shape[0],
        'dtype': dtype,
        'driver': 'GTiff',
        'transform': tfm
    }
    if proj is not None:
        meta["crs"] = {'init': proj}

    if kwargs.get("tiled"):
        meta.update(blockxsize=x_size, blockysize=y_size, tiled="yes")

    with rasterio.open(path, "w", **meta) as dst:
        writer = rio_writer(dst)
        # Every worker thread writes through the same dataset handle, so
        # serialize the chunk writes with store()'s lock (made explicit here).
        result = store(arr, writer, lock=True, compute=False)
        result.compute(scheduler=threaded_get)

    return path