from keras import backend as K
from keras.layers import Embedding
import numpy as np



class DynamicEmbedding(Embedding):
    """Embedding layer whose lookup table is another Keras tensor.

    In 'matrix' mode the table is a single (input_dim, output_dim) matrix
    shared across the batch; in 'tensor' mode it is a
    (batch, input_dim, output_dim) tensor, i.e. one table per example.
    """
    def __init__(self, embedding_matrix, mode='matrix', *args, **kwargs):
        # the table must be a Keras tensor so its static shape can be read
        assert hasattr(embedding_matrix, '_keras_shape')
        self.W = embedding_matrix
        if mode == 'tensor':
            # per-example tables: (batch, input_dim, output_dim)
            assert len(embedding_matrix._keras_shape) == 3
            indim = self.W._keras_shape[1]
            outdim = self.W._keras_shape[2]
        else:
            # one shared table: (input_dim, output_dim)
            assert len(embedding_matrix._keras_shape) == 2
            indim, outdim = self.W._keras_shape

        self.mode = mode
        # masking is forced on
        kwargs['mask_zero'] = True
        super(DynamicEmbedding, self).__init__(indim, outdim, *args, **kwargs)

        #layer, node_index, tensor_index = self.W._keras_history
        #self.add_inbound_node(layer, node_index, tensor_index)
        
        
    def __call__(self, x, mask=None):
        # hacky: pass the embedding tensor along with the indices so that
        # Keras wires self.W into this layer's inbound node
        return super(DynamicEmbedding, self).__call__([x, self.W], mask)

    def build(self, input_shape):
        # no new weights are created here: self.W is supplied from outside,
        # so we only register its constraints/regularizers
        if isinstance(input_shape, list):
            input_shape, _ = input_shape

        self.constraints = {}
        if self.W_constraint:
            self.constraints[self.W] = self.W_constraint

        self.regularizers = []
        if self.W_regularizer:
            self.W_regularizer.set_param(self.W)
            self.regularizers.append(self.W_regularizer)

        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)

        self.built = True

    def compute_mask(self, x, mask=None):
        if isinstance(x, list):
            x, _ = x
        if mask is not None and isinstance(mask, list):
            mask, _ = mask
        return super(DynamicEmbedding, self).compute_mask(x, mask)

    def get_output_shape_for(self, input_shape):
        if isinstance(input_shape, list):
            input_shape, _ = input_shape
        return super(DynamicEmbedding, self).get_output_shape_for(input_shape)

    def call(self, x, mask=None):
        if isinstance(x, list):
            x, _ = x
        if mask is not None and isinstance(mask, list):
            mask, _ = mask
        if 0. < self.dropout < 1.:
            # drop entire rows of the embedding table, rescaled so the
            # expected value is unchanged; only applied in the train phase
            retain_p = 1. - self.dropout
            dims = self.W._keras_shape[:-1]
            B = K.random_binomial(dims, p=retain_p) * (1. / retain_p)
            B = K.expand_dims(B)
            W = K.in_train_phase(self.W * B, self.W)
        else:
            W = self.W

        if self.mode == 'matrix':
            # a single shared table: plain row lookup
            return K.gather(W, x)
        elif self.mode == 'tensor':
            # quick and dirty: only 3-dim index inputs are supported in tensor mode
            assert K.ndim(x) == 3
            # W is (B, S, F) and the incoming x is (B, S, A): put the sequence
            # axis first, gather, then take the diagonal across the shared
            # batch dimension.
            inds = K.arange(self.W._keras_shape[0])
            # (a gather + .diagonal(axis1=0, axis2=3) version computed the same
            # values but did not propagate gradients)
            # tensor (b, s, f) goes to (s, b, f); indexing with x of shape
            # (x, y, z) gives (x, y, z, b, f); since x == b, reorder to
            # (x, b, y, z, f) == (x, x, y, z, f) and take the diagonal over the
            # first two axes, leaving (x, y, z, f)
            out = K.gather(K.permute_dimensions(W, (1, 0, 2)), x)
            out = K.permute_dimensions(out, (0, 3, 1, 2, 4))
            out = K.gather(out, (inds, inds))
            return out
        else:
            raise ValueError('`mode` must be "matrix" or "tensor", got: %r' % self.mode)

        # earlier scratch notes, kept for reference:
        # all_dims = T.arange(len(self.W._keras_shape))
        # first_shuffle = [all_dims[self.embed_dim]] + all_dims[:self.embed_dim] + all_dims[self.embed_dim+1:]
        ## 1. take diagonal from 0th to
        ## change of tactics:
        ## embed on time or embed on batch -- that's all I'm supporting.
        ## if it's embed on time, then x.ndim+1 is where batch will be, and is what
        ## I need to take the diagonal over.
        ## now dimshuffle the xdims + 1 to the front.
        # todo: get second shuffle or maybe find diagonal calculations
        # out = K.gather(W, x)
        # return out

        ### reference (Theano; `S` presumably wraps a shared variable, `T` is theano.tensor)
        # A = S(np.arange(60).reshape(3, 4, 5))         # per-example tables, (B, S, F) = (3, 4, 5)
        # x = S(np.random.randint(0, 4, (3, 4, 10)))    # indices, (B, S, A) = (3, 4, 10)
        # x_emb = A.dimshuffle(1,0,2)[x].dimshuffle(0,3,1,2,4)[T.arange(A.shape[0]), T.arange(A.shape[0])]
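

# Minimal usage sketch (not part of the original layer): 'tensor' mode, where
# every example in the batch carries its own (vocab, features) embedding
# table. Assumes the Keras 1.x functional API; names and sizes below are
# illustrative only, and tensor mode needs a fixed batch size because
# K.arange(self.W._keras_shape[0]) must be concrete.
if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model

    batch_size, vocab_size, embed_dim, seq_len, n_lookups = 16, 50, 8, 10, 1

    # one embedding table per example: (batch, vocab, features)
    tables = Input(batch_shape=(batch_size, vocab_size, embed_dim))
    # integer indices into each example's own table: (batch, seq, lookups)
    indices = Input(batch_shape=(batch_size, seq_len, n_lookups), dtype='int32')

    # the layer gathers rows from each example's table; the output of the
    # permute/gather/diagonal trick is (batch, seq, lookups, features)
    embedded = DynamicEmbedding(tables, mode='tensor')(indices)
    model = Model(input=[indices, tables], output=embedded)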