from typing import Union, ByteString, Sequence, List, AnyStr, Callable
import numpy as np


# Lookup from a numpy dtype name (the string form of ``ndarray.dtype``) to the
# upper-case dtype identifier handed to RedisAI.  'float'/'double' are aliases
# for the 32-bit and 64-bit floating point entries.
dtype_dict = dict(
    float='FLOAT',
    double='DOUBLE',
    float32='FLOAT',
    float64='DOUBLE',
    int8='INT8',
    int16='INT16',
    int32='INT32',
    int64='INT64',
    uint8='UINT8',
    uint16='UINT16',
    uint32='UINT32',
    uint64='UINT64',
)

# Accepted device identifiers.
allowed_devices = {
    'CPU',
    'GPU',
}
# Accepted backend identifiers.
allowed_backends = {
    'TF',
    'TFLITE',
    'TORCH',
    'ONNX',
}


def numpy2blob(tensor: np.ndarray) -> tuple:
    """Split a numpy array into the pieces needed to send it to RedisAI.

    Args:
        tensor: The array to convert.

    Returns:
        A ``(dtype, shape, blob)`` tuple: the dtype identifier looked up in
        ``dtype_dict``, the array's ``shape`` tuple, and the array's data
        as ``bytes``.

    Raises:
        TypeError: If the array's dtype has no entry in ``dtype_dict``.
    """
    # The lookup key is the textual dtype name, e.g. 'float32'.
    numpy_name = str(tensor.dtype)
    try:
        redis_dtype = dtype_dict[numpy_name]
    except KeyError:
        raise TypeError(f"RedisAI doesn't support tensors of type {tensor.dtype}")
    return redis_dtype, tensor.shape, bytes(tensor.data)


def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.ndarray:
    """Build an ``np.ndarray`` from a RedisAI ``BLOB`` reply.

    Args:
        value: Raw buffer holding the tensor data.
        shape: Target shape passed to ``reshape``.
        dtype: RedisAI dtype identifier. ``'FLOAT'`` and ``'DOUBLE'`` are
            translated to ``'float32'`` and ``'float64'``; any other
            identifier is lower-cased and used as the numpy dtype name.

    Returns:
        The array produced by ``np.frombuffer`` over ``value``, reshaped to
        ``shape``.
    """
    # Lower-case first so a non-string ``dtype`` fails the same way for
    # every identifier, including the two special-cased ones.
    lowered = dtype.lower()
    if dtype == 'FLOAT':
        numpy_dtype = 'float32'
    elif dtype == 'DOUBLE':
        numpy_dtype = 'float64'
    else:
        numpy_dtype = lowered
    return np.frombuffer(value, dtype=numpy_dtype).reshape(shape)


def list2dict(lst):
    """Convert a flat ``[key, value, key, value, ...]`` reply into a dict.

    Keys are lower-cased. ``bytes`` keys are decoded first; ``str`` keys are
    accepted as they are. ``bytes`` values are decoded too, except the value
    stored under the ``'blob'`` key, which is returned untouched. Values of
    any other type are passed through unchanged.

    Args:
        lst: Sequence with an even number of elements, alternating keys
            and values.

    Returns:
        A dict mapping each lower-cased key to its value.

    Raises:
        RuntimeError: If ``lst`` has an odd number of elements.
    """
    if len(lst) % 2 != 0:
        raise RuntimeError("Can't unpack the list: {}".format(lst))
    out = {}
    # Pulling twice per step from one shared iterator yields (key, value) pairs.
    flat = iter(lst)
    for key, val in zip(flat, flat):
        # Keys get the same bytes-or-str tolerance the values already have.
        if isinstance(key, bytes):
            key = key.decode()
        key = key.lower()
        if key != 'blob' and isinstance(val, bytes):
            val = val.decode()
        out[key] = val
    return out


def recursive_bytetransform(arr: List[AnyStr], target: Callable) -> list:
    """
    Apply ``target`` to every non-list element of ``arr``, in place.

    Nested lists are walked recursively; any element that is not a ``list``
    (other containers included) is replaced by ``target(element)``.

    Returns the same ``arr`` object that was passed in, after it has been
    updated.
    """
    for position, element in enumerate(arr):
        if isinstance(element, list):
            # Nested lists are updated in place, so no reassignment is needed.
            recursive_bytetransform(element, target)
            continue
        arr[position] = target(element)
    return arr


def listify(inp: Union[str, Sequence[str]]) -> Sequence[str]:
    """Return ``inp`` as a sequence.

    A ``list`` or ``tuple`` is returned unchanged (the same object); anything
    else is wrapped in a one-element tuple.
    """
    if isinstance(inp, (list, tuple)):
        return inp
    return (inp,)