"""Definitions for the primitive `array_reduce`."""

import numpy as np

from ..lib import (
    ANYTHING,
    SHAPE,
    TYPE,
    AbstractArray,
    AbstractFunctionBase,
    MetaGraph,
    MyiaShapeError,
    bprop_to_grad_transform,
    build_value,
    force_pending,
    newenv,
    standard_prim,
    u64tup_typecheck,
)
from ..operations import distribute, shape, zeros_like
from . import primitives as P


def pyimpl_array_reduce(fn, array, shp):
    """Implement `array_reduce`."""
    idtype = array.dtype
    ufn = np.frompyfunc(fn, 2, 1)
    delta = len(array.shape) - len(shp)
    if delta < 0:
        raise ValueError("Shape to reduce to cannot be larger than original")

    def is_reduction(ishp, tshp):
        if tshp == 1 and ishp > 1:
            return True
        elif tshp != ishp:
            raise ValueError("Dimension mismatch for reduce")
        else:
            return False

    reduction = [
        (delta + idx if is_reduction(ishp, tshp) else None, True)
        for idx, (ishp, tshp) in enumerate(zip(array.shape[delta:], shp))
    ]

    reduction = [(i, False) for i in range(delta)] + reduction

    for idx, keep in reversed(reduction):
        if idx is not None:
            array = ufn.reduce(array, axis=idx, keepdims=keep)

    if not isinstance(array, np.ndarray):
        # Force result to be ndarray, even if it's 0d
        array = np.array(array)

    array = array.astype(idtype)

    return array


def debugvm_array_reduce(vm, fn, array, shp):
    """Implement `array_reduce` for the debug VM."""

    def fn_(a, b):
        return vm.call(fn, [a, b])

    return pyimpl_array_reduce(fn_, array, shp)


@standard_prim(P.array_reduce)
async def infer_array_reduce(
    self,
    engine,
    fn: AbstractFunctionBase,
    a: AbstractArray,
    shp: u64tup_typecheck,
):
    """Infer the return type of primitive `array_reduce`."""
    shp_i = await force_pending(a.xshape())
    shp_v = build_value(shp, default=ANYTHING)
    if shp_v == ANYTHING:
        raise AssertionError(
            "We currently require knowing the shape for reduce."
        )
        # return (ANYTHING,) * (len(shp_i) - 1)
    else:
        delta = len(shp_i) - len(shp_v)
        if delta < 0 or any(
            1 != s1 != ANYTHING and 1 != s2 != ANYTHING and s1 != s2
            for s1, s2 in zip(shp_i[delta:], shp_v)
        ):
            raise MyiaShapeError(
                f"Incompatible dims for reduce: {shp_i}, {shp_v}"
            )

    res = await engine.execute(fn, a.element, a.element)
    return type(a)(res, {SHAPE: shp_v, TYPE: a.xtype()})


def bprop_sum(fn, xs, shp, out, dout):  # pragma: no cover
    """Backpropagator for sum(xs) = array_reduce(scalar_add, xs, shp)."""
    return (newenv, distribute(dout, shape(xs)), zeros_like(shp))


class ArrayReduceGradient(MetaGraph):
    """Generate the gradient graph for array_reduce.

    For the time being, the gradient of array_reduce is only supported
    over the `scalar_add` operation (sum, basically).
    """

    def generate_graph(self, args):
        """Generate the gradient graph."""
        # BUG: We only support gradients for sum (scalar_add as the reduction
        # function). However, it is currently not possible to check what the
        # reduction function is, due to erasure of the information when
        # GraphFunction and so on are converted to VirtualFunction.
        return bprop_to_grad_transform(P.array_reduce)(bprop_sum)


__operation_defaults__ = {
    "name": "array_reduce",
    "registered_name": "array_reduce",
    "mapping": P.array_reduce,
    "python_implementation": pyimpl_array_reduce,
}


__primitive_defaults__ = {
    "name": "array_reduce",
    "registered_name": "array_reduce",
    "type": "backend",
    "python_implementation": pyimpl_array_reduce,
    "debugvm_implementation": debugvm_array_reduce,
    "inferrer_constructor": infer_array_reduce,
    "grad_transform": ArrayReduceGradient(name="array_reduce_gradient"),
}