#!/usr/bin/env python
# leike_ensslin_2019.py
# The Leike & Ensslin (2019) dust map.
# Copyright (C) 2019  Gregory M. Green
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from __future__ import print_function, division

import numpy as np
import h5py
import astropy.coordinates as coordinates
import astropy.units as units
from astropy.coordinates import Longitude

from .map_base import DustMap, ensure_flat_coords
from .std_paths import *
from . import fetch_utils

class LeikeEnsslin2019Query(DustMap):
    A class for querying the Leike & Ensslin (2019) dust map.

    def __init__(self, map_fname=None):
            map_fname (Optional[str]): Filename of the map. Defaults
                to :obj:`None`, meaning that the default location
                is used.

        if map_fname is None:
            map_fname = os.path.join(

        self._data = {}

        with h5py.File(map_fname) as f:
            self._data['mean'] = f['mean'][:]
            self._data['std'] = f['std'][:]

        self._shape = self._data['mean'].shape

    def _coords2idx(self, coords):
        c = coords.transform_to('galactic').represent_as('cartesian')
        idx = np.empty((3,) + c.shape, dtype='i4')
        mask = np.zeros(c.shape, dtype=np.bool)

        for i,x in enumerate((c.x, c.y, c.z)):
            idx[i,...] = np.floor(x.to('pc').value + 300) * 256/600.
            mask |= (idx[i] < 0) | (idx[i] >= self._shape[i])

        for i in range(3):
            idx[i, mask] = -1

        return idx, mask

    def query(self, coords, component='mean'):
        Returns the extinction density (in e-foldings / kpc, in Gaia G-band)
        at the given coordinates.

            coords (:obj:`astropy.coordinates.SkyCoord`): Coordinates at which
                to query the extinction. Must be 3D (i.e., include distance
            component (str): Which component to return. Allowable values are
                'mean' (for the mean extinction density) and 'std' (for the
                standard deviation of extinction density). Defaults to 'mean'.

            The extinction density, in units of e-foldings / pc, as either a
            numpy array or float, with the same shape as the input
        idx,mask = self._coords2idx(coords)

        v = self._data[component][idx[0], idx[1], idx[2]]
        if np.any(mask):
            # Set extinction to NaN for out-of-bounds (x, y, z)
            v[mask] = np.nan

        return v

def fetch(clobber=False):
    Downloads the 3D dust map of Leike & Ensslin (2019).

        clobber (Optional[bool]): If ``True``, any existing file will be
            overwritten, even if it appears to match. If ``False`` (the
            default), ``fetch()`` will attempt to determine if the dataset
            already exists. This determination is not 100\% robust against data
    dest_dir = fname_pattern = os.path.join(data_dir(), 'leike_ensslin_2019')
    fname = os.path.join(dest_dir, 'simple_cube.h5')
    # Check if the FITS table already exists
    md5sum = 'f54e01c253453117e3770575bed35078'

    if (not clobber) and fetch_utils.check_md5sum(fname, md5sum):
        print('File appears to exist already. Call `fetch(clobber=True)` '
              'to force overwriting of existing file.')

    # Download from the server
    url = 'https://zenodo.org/record/2577337/files/simple_cube.h5?download=1'
    fetch_utils.download_and_verify(url, md5sum, fname)