#! /usr/bin/env python
# -*- coding: utf-8 -*-

__author__ = 'maxim'

import math
from scipy import stats

from .nodes import *

def wrap(node, transform):
  if transform is not None:
    return MergeNode(transform, node)
  return node

def uniform(start=0.0, end=1.0, transform=None, name=None):
  node = UniformNode(start, end).with_name(name)
  return wrap(node, transform)

def normal(mean=0.0, stdev=1.0, name=None):
  return NonUniformNode(ppf=stats.norm.ppf, loc=mean, scale=stdev).with_name(name)

def choice(array, transform=None, name=None):
  if not [item for item in array if isinstance(item, BaseNode)]:
    node = ChoiceNode(*array).with_name(name)
  else:
    node = MergeChoiceNode(*array).with_name(name)
  return wrap(node, transform)

def merge(nodes, function, name=None):
  if callable(nodes) and not callable(function):
    nodes, function = function, nodes
  if isinstance(nodes, BaseNode):
    nodes = [nodes]
  return MergeNode(function, *nodes).with_name(name)

def random_bit():
  return choice([0, 1])

def random_bool():
  return choice([False, True])

def random_int(n):
  return choice(range(n))

def exp(node): return merge([node], math.exp)
def expm1(node): return merge([node], math.expm1)
def frexp(node): return merge([node], math.frexp)
def ldexp(node, i): return merge([node], lambda x: math.ldexp(x, i))

def sqrt(node): return merge([node], math.sqrt)
def pow(a, b): return a ** b

def log(node, base=None): return merge([node], lambda x: math.log(x, base))
def log1p(node): return merge([node], math.log1p)
def log10(node): return merge([node], math.log10)

def sin(node): return merge([node], math.sin)
def cos(node): return merge([node], math.cos)
def tan(node): return merge([node], math.tan)

def sinh(node): return merge([node], math.sinh)
def cosh(node): return merge([node], math.cosh)
def tanh(node): return merge([node], math.tanh)

def asin(node): return merge([node], math.asin)
def acos(node): return merge([node], math.acos)
def atan(node): return merge([node], math.atan)
def atan2(node): return merge([node], math.atan2)

def asinh(node): return merge([node], math.asinh)
def acosh(node): return merge([node], math.acosh)
def atanh(node): return merge([node], math.atanh)

def min_(*array):
  nodes = [item for item in array if isinstance(item, BaseNode)]
  if len(nodes) == 0:
    return min(*array) if len(array) > 1 else array[0]
  node = merge(nodes, min) if len(nodes) > 1 else nodes[0]

  rest = [item for item in array if not isinstance(item, BaseNode)]
  if rest:
    node = merge([node], lambda x: min(x, *rest))
  return node

def max_(*array):
  nodes = [item for item in array if isinstance(item, BaseNode)]
  if len(nodes) == 0:
    return max(*array) if len(array) > 1 else array[0]
  node = merge(nodes, max) if len(nodes) > 1 else nodes[0]

  rest = [item for item in array if not isinstance(item, BaseNode)]
  if rest:
    node = merge([node], lambda x: max(x, *rest))
  return node

def new(*args, **kwargs):
  from ..base import NamedDict
  if len(args) == 1 and len(kwargs) == 0:
    return NamedDict(args[0])
  assert len(args) == 0, 'Failed to created a NamedDict with arguments: %s' % str(args)
  return NamedDict(kwargs)