import numpy as np
import tensorflow as tf

def concrete_shape(shape):
    """Normalize a shape specification to a plain tuple.

    Args:
        shape: a tuple (returned unchanged), a list (converted to a tuple),
            or a ``tf.TensorShape`` (converted to a tuple whose entries are
            ints, or ``None`` for unknown dimensions).

    Returns:
        A tuple describing the shape.

    Raises:
        TypeError: if ``shape`` is none of the supported types. This is a
            subclass of ``Exception``, so existing handlers keep working.
        ValueError: propagated from ``TensorShape.as_list()`` when the
            TensorShape has unknown rank.
    """
    if isinstance(shape, tuple):
        return shape
    if isinstance(shape, list):
        return tuple(shape)
    if isinstance(shape, tf.TensorShape):
        # as_list() yields ints / None on both TF1 (where iteration yields
        # Dimension objects) and TF2 (where iteration yields ints).
        return tuple(shape.as_list())
    raise TypeError("don't know how to interpret %s as a tensor shape" % shape)
    
def extract_shape(t):
    """Return the static (graph-time) shape of tensor ``t`` as a tuple.

    Known dimensions are ints and unknown dimensions are ``None``. As a
    special case, a rank-1 shape whose single dimension is unknown,
    ``(None,)``, is reported as the empty shape ``()``.

    Raises:
        ValueError: propagated from ``TensorShape.as_list()`` when the
            static rank of ``t`` is unknown.
    """
    shape = tuple(t.get_shape().as_list())
    # NOTE(review): the (None,) -> () collapse is preserved from the original;
    # presumably callers treat unknown-length vectors as scalars. Confirm
    # before relying on it.
    if shape == (None,):
        shape = ()
    return shape

def shapes_equal(shape1, shape2):
    """Return True if two shapes have the same rank and identical dimensions.

    Each argument may be anything ``np.asarray`` accepts (tuple, list,
    scalar, ...). ``None`` entries (unknown dimensions) compare equal only
    to ``None``.

    Returns False, rather than raising, when the inputs cannot be turned
    into comparable arrays.
    """
    try:
        ns1 = np.asarray(shape1)
        ns2 = np.asarray(shape2)
        # Compare ranks first: an elementwise ``==`` would otherwise
        # broadcast, making e.g. (3,) look equal to (3, 3) and () look equal
        # to (1,).
        if ns1.shape != ns2.shape:
            return False
        return bool(np.all(ns1 == ns2))
    except (TypeError, ValueError):
        return False

def shape_is_scalar(shape):
    """Report whether every dimension of ``shape`` equals 1.

    The check broadcasts against ``(1,)``. As a result, ``(1,)``, ``(1, 1)``
    and the empty shape ``()`` all count as scalar. Any dimension other than
    1, including ``None``, does not. Returns a NumPy boolean.
    """
    unit_dims = np.asarray(shape) == (1,)
    return unit_dims.all()

def logsumexp(x1, x2):
    """Elementwise, numerically stable ``log(exp(x1) + exp(x2))``.

    Args:
        x1, x2: tensors broadcastable against each other.

    Returns:
        A tensor of the broadcast shape.
    """
    shift = tf.maximum(x1, x2)
    # If both inputs are -inf, or either is +inf, the shift is infinite and
    # ``x - shift`` would be nan (-inf - -inf, inf - inf). Shifting by 0 in
    # that case yields the correct -inf / +inf result instead.
    shift = tf.where(tf.is_finite(shift), shift, tf.zeros_like(shift))
    # The shift cancels analytically, so its gradient contributes nothing;
    # stopping it avoids differentiating through maximum/where.
    shift = tf.stop_gradient(shift)
    return tf.log(tf.exp(x1 - shift) + tf.exp(x2 - shift)) + shift

def reduce_logsumexp(x, **kwargs):
    """Numerically stable ``log(sum(exp(x)))`` over the reduced axes.

    Args:
        x: input tensor.
        **kwargs: forwarded unchanged to ``tf.reduce_logsumexp``, which
            accepts the same reduction keywords as ``tf.reduce_max`` /
            ``tf.reduce_sum`` (e.g. axis, keep_dims, name).

    The previous hand-rolled version subtracted a max computed without
    keep_dims. When reducing over a non-leading axis, that max broadcast
    against ``x`` along the wrong axis. The TensorFlow op keeps the max's
    dimensions for the subtraction and squeezes afterwards.
    """
    return tf.reduce_logsumexp(x, **kwargs)

def triangular_inv(L, lower=True):
    """Invert a triangular matrix, or a batch of them, by a triangular solve.

    Args:
        L: tensor of shape [..., N, N] holding triangular matrices.
        lower: whether ``L`` is lower (default, as before) or upper
            triangular. Only that triangle of ``L`` is read by the solve.

    Returns:
        A tensor of the same shape as ``L`` containing ``inv(L)``.
    """
    # Build an identity with L's dtype and (batch) shape from its diagonal,
    # so this works when the matrix size is only known at run time.
    # For 2-D input this matches the former tf.diag(tf.diag_part(...)) form.
    eye = tf.matrix_diag(tf.ones_like(tf.matrix_diag_part(L)))
    return tf.matrix_triangular_solve(L, eye, lower=lower)

def broadcast_shape(**shapes):
    """Return the shape obtained by broadcasting all given shapes together.

    Shapes are passed as keyword arguments. The keyword names are ignored;
    only the values (tuples or ints) are used.

    Raises:
        ValueError: if the shapes are not broadcast-compatible or contain
            negative dimensions.
    """
    # Zero-stride views of one 0-d array stand in for arrays of each shape,
    # so no memory proportional to the shapes is allocated.
    scalar = np.empty(())
    views = [np.broadcast_to(scalar, shape) for shape in shapes.values()]
    return np.broadcast(*views).shape
    

def differentiable_sq_singular_vals(A):
    """Squared singular values of ``A`` with a usable gradient.

    Args:
        A: tensor of shape [..., M, N].

    Returns:
        Tensor of shape [..., K], K = min(M, N), whose i-th entry is
        ``||A v_i||^2``, the squared i-th singular value.

    The right singular vectors ``v`` are treated as constants
    (stop_gradient), so the gradient only flows through the matmul with
    ``A``. The original author notes TF's SVD was not differentiable in
    their setup and calls the result an approximate derivative.
    NOTE(review): since A v_i = s_i u_i, the resulting gradient
    2 A v_i v_i^T equals 2 s_i u_i v_i^T for distinct singular values —
    confirm before relying on exactness.
    """
    _, _, v = tf.svd(A)
    v_fixed = tf.stop_gradient(v)
    # Column i of A @ v is s_i * u_i; its squared norm is s_i^2.
    av = tf.matmul(A, v_fixed)
    # Reduce over the row axis (-2), which is axis 0 for a plain matrix and
    # also correct for batched input.
    return tf.reduce_sum(tf.square(av), -2)