from typing import Tuple, Any, Dict, Type
import tensorflow as tf
from tensorflow import Tensor

from decompose.distributions.distribution import Distribution
from decompose.distributions.distribution import ParameterInfo
from decompose.distributions.nnNormal import NnNormal
from decompose.distributions.algorithms import Algorithms
from decompose.distributions.cenNnNormalAlgorithms import CenNnNormalAlgorithms
from decompose.distributions.distribution import Properties


class CenNnNormal(NnNormal):
    """`NnNormal` subclass whose only stored parameter is `tau`.

    The `mu` property is not stored; it is computed as an all-zero tensor
    shaped like `tau`.

    NOTE(review): the class name suggests a "centered non-negative normal"
    distribution -- confirm against `NnNormal`.
    """

    def __init__(self,
                 algorithms: Type[Algorithms] = CenNnNormalAlgorithms,
                 tau: Tensor = None,
                 properties: Properties = None) -> None:
        """Initialize the distribution with `tau` as its single parameter.

        Args:
            algorithms: Algorithms class forwarded to `Distribution`;
                defaults to `CenNnNormalAlgorithms`.
            tau: Value registered under the parameter name ``"tau"``.
                `None` is forwarded unchanged.
            properties: Properties object forwarded to `Distribution`.
        """
        # The base `Distribution` initializer is invoked directly, so
        # `NnNormal.__init__` is bypassed (presumably because it expects
        # additional parameters such as `mu` -- confirm).
        Distribution.__init__(self,
                              algorithms=algorithms,
                              parameters={"tau": tau},
                              properties=properties)

    def parameterInfo(self,
                      shape: Tuple[int, ...] = (1,),
                      latentShape: Tuple[int, ...] = ()) -> ParameterInfo:
        """Describe the parameters of this distribution.

        Args:
            shape: Shape reported for the `tau` parameter.
            latentShape: Accepted but not used by this implementation.

        Returns:
            A dict with the single key ``"tau"`` mapped to ``(shape, True)``.
            NOTE(review): the meaning of the boolean flag is defined by the
            consumer of `ParameterInfo` -- confirm there.
        """
        return {"tau": (shape, True)}

    @property
    def mu(self) -> Tensor:
        """Return a tensor of zeros with the same shape and dtype as `tau`."""
        return tf.zeros_like(self.tau)