from __future__ import print_function
import numpy as np
from scipy.optimize import curve_fit, minimize
from sklearn.decomposition import FactorAnalysis
from copy import deepcopy
import warnings

"""
Zero-inflated factor analysis (ZIFA). Performs dimensionality reduction on zero-inflated data.
Created by Emma Pierson and Christopher Yau.

Sample usage:
Z, model_params = fitModel(Y, k)

where Y is the observed zero-inflated data, k is the desired number of latent dimensions, and Z is the low-dimensional projection.

See example.py for a full example.
"""


def mult_diag(d, mtx, left=True):
	"""
	Multiply a full matrix by a diagonal matrix.
	This function should always be faster than dot.
	Input:
	d -- 1D (N,) array (contains the diagonal elements)
	mtx -- 2D (N,N) array
	Returns:
	mult_diag(d, mts, left=True) == dot(diag(d), mtx)
	mult_diag(d, mts, left=False) == dot(mtx, diag(d))
	"""
	if left:
		return (d * mtx.T).T
	else:
		return d * mtx
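

# Illustrative check (defined but never called): mult_diag should agree with
# explicitly forming the diagonal matrix, while avoiding the O(N^2) construction.
def _example_mult_diag():
	d = np.array([1., 2., 3.])
	mtx = np.arange(9.).reshape(3, 3)
	assert np.allclose(mult_diag(d, mtx, left=True), np.dot(np.diag(d), mtx))
	assert np.allclose(mult_diag(d, mtx, left=False), np.dot(mtx, np.diag(d)))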


def checkNoNans(matrix_list):
	"""
	Returns false if any of the matrices are nans or infinite.
	Checked.
	"""
	for i, M in enumerate(matrix_list):
		if np.any(np.isnan(np.array(M))) or np.any(np.isinf(np.array(M))):
			raise Exception('Matrix index %i in list has a NaN or infinite element' % i)


def invertFast(A, d):
	"""
	given an array A of shape d x k and a d x 1 vector d, computes (A * A.T + diag(d)) ^{-1}
	Checked.
	"""
	assert(A.shape[0] == d.shape[0])
	assert(d.shape[1] == 1)

	k = A.shape[1]
	A = np.array(A)
	d_vec = np.array(d)
	d_inv = np.array(1 / d_vec[:, 0])

	inv_d_squared = np.dot(np.atleast_2d(d_inv).T, np.atleast_2d(d_inv))
	M = np.diag(d_inv) - inv_d_squared * np.dot(np.dot(A, np.linalg.inv(np.eye(k, k) + np.dot(A.T, mult_diag(d_inv, A)))), A.T)

	return M
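

# Illustrative check (defined but never called): invertFast uses the Woodbury
# identity, so on a small random example it should match the direct inverse.
def _example_invertFast():
	rng = np.random.RandomState(0)
	A = rng.randn(5, 2)
	d = rng.rand(5, 1) + 1.  # positive diagonal entries
	direct = np.linalg.inv(np.dot(A, A.T) + np.diag(d[:, 0]))
	assert np.allclose(invertFast(A, d), direct)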


def Estep(Y, A, mus, sigmas, decay_coef):
	"""
	estimates the requisite latent expectations in the E-step.
	Checked.
	Input:
	Y: observed data.
	A, mus, sigmas, decay_coef: parameters.
	Returns:
	EZ, EZZT, EX, EXZ, EX2: latent expectations.
	"""
	assert(len(Y[0]) == len(A))

	N, D = Y.shape
	D, K = A.shape
	assert((sigmas.shape[0] == D) and (sigmas.shape[1] == 1))
	assert((mus.shape[0] == D) and (mus.shape[1] == 1))

	EX = np.zeros([N, D])
	EXZ = np.zeros([N, D, K])  # this is a 3D tensor.
	EX2 = np.zeros([N, D])
	EZ = np.zeros([N, K])
	EZZT = np.zeros([N, K, K])

	for i in range(N):
		# compute P(Z, X_0 | Y) following three step formula.
		# 1. compute P(Z, X_0)
		Y_i = Y[i, :]
		Y_is_zero = np.abs(Y_i) < 1e-6
		dim = K + D  # this is dimension of matrix

		# 2. compute P(Z, X_0 | Y_+)
		mu_c, sigma_c, augmentedA_0, augmentedA_plus, augmented_D, sigma_22_inv = calcConditionalDistribution(A, mus, sigmas, Y_is_zero, Y_i[~Y_is_zero])

		# 3. compute P(Z, X_0 | Y_+, Y_0)
		dim = len(sigma_c)
		matrixToInvert = computeMatrixInLastStep(A, Y_is_zero, sigmas, K, sigma_c, decay_coef, sigma_22_inv)

		if (Y_is_zero).sum() < D:
			magical_matrix = 2 * decay_coef * (mult_diag(augmented_D, matrixToInvert) + augmentedA_0 * (np.eye(K) - augmentedA_plus.T * sigma_22_inv * augmentedA_plus) * (augmentedA_0.T * matrixToInvert))
		else:
			magical_matrix = 2 * decay_coef * (mult_diag(augmented_D, matrixToInvert) + augmentedA_0 * (augmentedA_0.T * matrixToInvert))
		magical_matrix[:, :K] = 0

		if (Y_is_zero).sum() < D:
			sigma_xz = np.array(sigma_c - mult_diag(augmented_D, np.array(magical_matrix), left=False) - (magical_matrix * augmentedA_0) * ((np.eye(K) - augmentedA_plus.T * sigma_22_inv * augmentedA_plus) * augmentedA_0.T))
		else:
			sigma_xz = np.array(sigma_c - mult_diag(augmented_D, np.array(magical_matrix), left=False) - (magical_matrix * augmentedA_0) * augmentedA_0.T)

		mu_xz = np.array(np.matrix(np.eye(dim) - magical_matrix) * np.matrix(mu_c))

		EZ[i, :] = mu_xz[:K, 0]
		EX[i, Y_is_zero] = mu_xz[K:, 0]
		EX2[i, Y_is_zero] = mu_xz[K:, 0] ** 2 + np.diag(sigma_xz[K:, K:])
		EZZT[i, :, :] = np.dot(np.atleast_2d(mu_xz[:K, :]), np.atleast_2d(mu_xz[:K, :].transpose())) + sigma_xz[:K, :K]
		EXZ[i, Y_is_zero, :] = np.dot(mu_xz[K:], mu_xz[:K].transpose()) + sigma_xz[K:, :K]

	return EZ, EZZT, EX, EXZ, EX2
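

# Illustrative shape sketch (defined but never called): the E-step returns
# per-sample expectations whose shapes follow from Y (N x D) and A (D x K).
def _example_Estep_shapes(Y, A, mus, sigmas, decay_coef):
	EZ, EZZT, EX, EXZ, EX2 = Estep(Y, A, mus, sigmas, decay_coef)
	N, D = Y.shape
	K = A.shape[1]
	assert EZ.shape == (N, K) and EZZT.shape == (N, K, K)
	assert EX.shape == (N, D) and EXZ.shape == (N, D, K) and EX2.shape == (N, D)
	return EZ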


def applyWoodburyIdentity(A_inv, B_inv, C):
	"""
	uses Woodbury identity to compute inverse of (A + C * B * C.T)
	"""
	A_inv = np.matrix(A_inv)
	B_inv = np.matrix(B_inv)
	C = np.matrix(C)

	A_inv_C = (A_inv * C)
	M = A_inv - A_inv_C * np.linalg.inv(B_inv + C.T * A_inv_C) * A_inv_C.T
	return M
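

# Illustrative check (defined but never called): given A^{-1} and B^{-1}, the
# Woodbury identity reproduces the direct inverse of (A + C * B * C.T).
def _example_applyWoodburyIdentity():
	rng = np.random.RandomState(0)
	A = np.diag(rng.rand(4) + 1.)
	B = np.diag(rng.rand(2) + 1.)
	C = rng.randn(4, 2)
	direct = np.linalg.inv(A + np.dot(np.dot(C, B), C.T))
	assert np.allclose(applyWoodburyIdentity(np.linalg.inv(A), np.linalg.inv(B), C), direct)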


def computeMatrixInLastStep(A, zero_indices, sigmas, K, sigma_c, decay_coef, sigma_22_inv):
	"""
	Optimized matrix method for the E-step.
	This computes (1 + 2 * decay_coef * I_x * sigma_c) ^ {-1}
	zero_indices should have length D.
	"""
	A_0 = np.matrix(A[zero_indices, :])
	A_plus = np.matrix(A[~zero_indices, :])
	sigmas_0 = sigmas[zero_indices]

	E_xz = sigma_c[K:, :][:, :K]
	E_00_prime_inv = np.matrix(invertFast(A_0, sigmas_0 ** 2 + 1 / (2. * decay_coef)))
	E_plusplus_inv = sigma_22_inv
	E_0plus = A_0 * A_plus.T

	if (E_plusplus_inv.shape[0] == 0) or (E_plusplus_inv.shape[1] == 0):
		inv_matrix = (1 / (2. * decay_coef)) * E_00_prime_inv

	elif (A_0.shape[0] < A_0.shape[1]):
		inv_matrix = np.linalg.inv(2. * decay_coef * (np.linalg.inv(E_00_prime_inv) - E_0plus * E_plusplus_inv * E_0plus.T))

	else:
		b_inv = np.linalg.inv((np.matrix(A_0).T * E_00_prime_inv) * np.matrix(A_0))
		innermost_inverse = applyWoodburyIdentity(-E_plusplus_inv, b_inv, np.matrix(A_plus))
		inv_matrix = (1 / (2. * decay_coef)) * (E_00_prime_inv - (E_00_prime_inv * A_0) * (A_plus.T * innermost_inverse * A_plus) * (A_0.T * E_00_prime_inv))

	dim = len(sigma_c)
	M = np.zeros([dim, dim])
	M[:K, :K] = np.eye(K)
	M[K:, :K] = -2 * decay_coef * np.dot(inv_matrix, E_xz)
	M[K:, K:] = inv_matrix

	return np.array(M)


def Mstep(Y, EZ, EZZT, EX, EXZ, EX2, oldA, old_mus, old_sigmas, old_decay_coef, singleSigma=False):
	"""
	estimates parameters given the expectations computed in the E-step.
	Input:
	Y: observed values.
	EZ, EZZT, EX, EXZ, EX2: expectations of latent variables computed in E-step.
	oldA, old_mus, old_sigmas, old_decay_coef: old parameters.
	singleSigma: only estimates one sigma as opposed to sigma_j for all j.
	Returns:
	A, mus, sigmas, decay_coef: new values of parameters.
	"""
	assert(len(Y) == len(EZ))
	N, D = Y.shape
	N, K = EZ.shape

	# First estimate A and mu
	A = np.zeros([D, K])
	mus = np.zeros([D, 1])
	sigmas = np.zeros([D, 1])
	Y_is_zero = np.abs(Y) < 1e-6

	# Make B, which is the same for all genes j.
	B = np.eye(K + 1)
	B[:K, :K] = EZZT.sum(axis=0)
	B[K, :K] = EZ.sum(axis=0)
	B[:K, K] = EZ.sum(axis=0)
	B[K, K] = N

	tiled_EZ = np.tile(np.resize(EZ, [N, 1, K]), [1, D, 1])
	tiled_Y = np.tile(np.resize(Y, [N, D, 1]), [1, 1, K])
	tiled_Y_is_zero = np.tile(np.resize(Y_is_zero, [N, D, 1]), [1, 1, K])

	c = np.zeros([K + 1, D])
	c[K, :] += (Y_is_zero * EX + (1 - Y_is_zero) * Y).sum(axis=0)
	c[:K, :] = (tiled_Y_is_zero * EXZ + (1 - tiled_Y_is_zero) * tiled_Y * tiled_EZ).sum(axis=0).transpose()

	solution = np.dot(np.linalg.inv(B), c)
	A = solution[:K, :].transpose()
	mus = np.atleast_2d(solution[K, :]).transpose()

	# Then optimize sigma
	EXM = np.zeros([N, D])  # these expectations depend on the updated mus, so they are computed after the A / mu update.
	EM = np.zeros([N, D])
	EM2 = np.zeros([N, D])

	tiled_mus = np.tile(mus.transpose(), [N, 1])
	tiled_A = np.tile(np.resize(A, [1, D, K]), [N, 1, 1])

	EXM = (tiled_A * EXZ).sum(axis=2) + tiled_mus * EX
	test_sum = (tiled_A * tiled_EZ).sum(axis=2)
	A_product = np.tile(np.reshape(A, [1, D, K]), [K, 1, 1]) * (np.tile(np.reshape(A, [1, D, K]), [K, 1, 1]).T)

	for i in range(N):
		EM[i, :] = (np.dot(A, EZ[i, :].transpose()) + mus.transpose())  # E[M_ij] = A_j . E[Z_i] + mu_j
		EZZT_tiled = np.tile(np.reshape(EZZT[i, :, :], [K, 1, K]), [1, D, 1])
		ezzt_sum = (EZZT_tiled * A_product).sum(axis=2).sum(axis=0)
		EM2[i, :] = ezzt_sum + 2 * test_sum[i, :] * tiled_mus[i, :] + tiled_mus[i, :] ** 2

	sigmas = (Y_is_zero * (EX2 - 2 * EXM + EM2) + (1 - Y_is_zero) * (Y ** 2 - 2 * Y * EM + EM2)).sum(axis=0)
	sigmas = np.atleast_2d(np.sqrt(sigmas / N)).transpose()

	if singleSigma:
		sigmas = np.mean(sigmas) * np.ones(sigmas.shape)

	decay_coef = minimize(lambda x: decayCoefObjectiveFn(x, Y, EX2), old_decay_coef, jac=True, bounds=[[1e-8, np.inf]])
	decay_coef = decay_coef.x[0]

	return A, mus, sigmas, decay_coef


def calcConditionalDistribution(A, mus, sigmas, zero_indices, observed_values):
	"""
	Computes the distribution of X and Z conditional on the NONZERO values of Y. Matrix computations optimized.
	Input:
	A, mus, sigmas are parameters.
	zero_indices has length D and zero_indices[j] is zero iff Y[j] is zero.
	observed_values are the values of Y at the nonzero indices
	Output: various matrices used in the rest of the E-step.

	"""
	D, K = A.shape
	dim = D + K

	# mu_x and sigma_x here are the PRIOR DISTRIBUTIONS over [Z, X]
	mu_x = np.zeros([dim, 1])
	mu_x[K:dim, :] = mus
	augmentedA = np.matrix(np.zeros([dim, K]))
	augmentedA[:K, :] = np.eye(K)
	augmentedA[K:, :] = A
	mu_x = np.atleast_2d(mu_x)
	observed_values = np.atleast_2d(observed_values)

	if len(observed_values) == 1:
		observed_values = observed_values.transpose()

	mu_diff = np.matrix(np.atleast_2d(observed_values - mus[~zero_indices]))

	assert(mu_diff.shape[0] == len(observed_values) and mu_diff.shape[1] == 1)
	assert(len(zero_indices) == D)
	assert(zero_indices.sum() == D - len(observed_values))

	augmented_zero_indices = np.array([True for a in range(K)] + list(zero_indices))

	# 1 denotes the zero entries, 2 denotes the nonzero entries
	augmentedA_0 = augmentedA[augmented_zero_indices, :]
	augmentedA_plus = augmentedA[~augmented_zero_indices, :]
	sigma_11 = augmentedA_0 * augmentedA_0.T
	sigma_11[K:, K:] = sigma_11[K:, K:] + np.diag(sigmas[zero_indices][:, 0] ** 2)
	augmented_D = np.array([0 for i in range(K)] + list(sigmas[zero_indices][:, 0] ** 2))

	if len(observed_values) == 0:
		sigma_x = augmentedA * augmentedA.T
		sigma_x[K:, K:] = sigma_x[K:, K:] + np.diag(sigmas[zero_indices][:, 0] ** 2)
		return mu_x, sigma_x, augmentedA_0, augmentedA_plus, augmented_D, np.array([[]])

	sigma_22_inv = np.matrix(invertFast(A[~zero_indices, :], sigmas[~zero_indices] ** 2))
	mu_0 = mu_x[augmented_zero_indices, :] + augmentedA_0 * (augmentedA_plus.T * (sigma_22_inv * mu_diff))
	sigma_0 = sigma_11 - augmentedA_0 * (augmentedA_plus.T * sigma_22_inv * augmentedA_plus) * augmentedA_0.T

	assert((mu_0.shape[0] == zero_indices.sum() + K) and (mu_0.shape[1] == 1))
	assert(sigma_0.shape[0] == zero_indices.sum() + K and sigma_0.shape[1] == zero_indices.sum() + K)

	return np.array(mu_0), np.array(sigma_0), augmentedA_0, augmentedA_plus, augmented_D, sigma_22_inv


def decayCoefObjectiveFn(x, Y, EX2):
	"""
	Computes the objective function for terms involving lambda in the M-step.
	Checked.
	Input:
	x: value of lambda
	Y: the matrix of observed values
	EX2: the matrix of values of EX2 estimated in the E-step.
	Returns:
	obj: value of objective function
	grad: gradient
	"""
	with warnings.catch_warnings():
		warnings.simplefilter("ignore")

		y_squared = Y ** 2
		Y_is_zero = np.abs(Y) < 1e-6
		exp_Y_squared = np.exp(-x * y_squared)
		log_exp_Y = np.nan_to_num(np.log(1 - exp_Y_squared))
		exp_ratio = np.nan_to_num(exp_Y_squared / (1 - exp_Y_squared))
		obj = sum(sum(Y_is_zero * (-EX2 * x) + (1 - Y_is_zero) * log_exp_Y))
		grad = sum(sum(Y_is_zero * (-EX2) + (1 - Y_is_zero) * y_squared * exp_ratio))

		if (type(obj) is not np.float64) or (type(grad) is not np.float64):
			raise Exception("Unexpected behavior in optimizing decay coefficient lambda. Please contact emmap1@cs.stanford.edu.")

		# negate so that scipy.optimize.minimize, which minimizes, maximizes this term of the expected log-likelihood
		obj = -np.array([obj])
		grad = -np.array([grad])

		return obj, grad
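

# Sketch of how this objective is consumed (mirrors the call in Mstep; defined
# but never called here): minimize receives both value and gradient via jac=True.
def _example_fit_decay_coef(Y, EX2, old_decay_coef=0.1):
	res = minimize(lambda x: decayCoefObjectiveFn(x, Y, EX2), old_decay_coef, jac=True, bounds=[[1e-8, np.inf]])
	return res.x[0]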


def exp_decay(x, decay_coef):
	"""
	squared exponential decay function.
	"""
	return np.exp(-decay_coef * (x ** 2))


def initializeParams(Y, K, singleSigma=False, makePlot=False):
	"""
	initializes parameters using a standard factor analysis model (on imputed data) + exponential curve fitting.
	Checked.
	Input:
	Y: data matrix, n_samples x n_genes
	K: number of latent components
	singleSigma: uses only a single sigma as opposed to a different sigma for every gene
	makePlot: makes a mu - p_0 plot and shows the decaying exponential fit.
	Returns:
	A, mus, sigmas, decay_coef: initialized model parameters.
	"""

	N, D = Y.shape
	model = FactorAnalysis(n_components=K)
	zeroedY = deepcopy(Y)
	mus = np.zeros([D, 1])

	for j in range(D):
		mus[j] = zeroedY[:, j].mean()
		zeroedY[:, j] = zeroedY[:, j] - mus[j]

	model.fit(zeroedY)

	A = model.components_.transpose()
	sigmas = np.atleast_2d(np.sqrt(model.noise_variance_)).transpose()
	if singleSigma:
		sigmas = np.mean(sigmas) * np.ones(sigmas.shape)

	# Now fit decay coefficient
	means = []
	ps = []
	for j in range(D):
		non_zero_idxs = np.abs(Y[:, j]) > 1e-6
		means.append(Y[non_zero_idxs, j].mean())
		ps.append(1 - non_zero_idxs.mean())

	decay_coef, pcov = curve_fit(exp_decay, means, ps, p0=.05)
	decay_coef = decay_coef[0]

	mse = np.mean(np.abs(ps - np.exp(-decay_coef * (np.array(means) ** 2))))  # mean absolute error of the exponential fit

	if (mse > 0) and makePlot:
		from matplotlib.pyplot import figure, scatter, plot, title, show
		figure()
		scatter(means, ps)
		plot(np.arange(min(means), max(means), .1), np.exp(-decay_coef * (np.arange(min(means), max(means), .1) ** 2)))
		title('Decay Coef is %2.3f; MSE is %2.3f' % (decay_coef, mse))
		show()

	return A, mus, sigmas, decay_coef
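

# Illustrative sketch (defined but never called): the initialization can be
# inspected on its own; shapes assume Y is n_samples x n_genes.
def _example_initializeParams(Y, K=2):
	A, mus, sigmas, decay_coef = initializeParams(Y, K)
	assert A.shape == (Y.shape[1], K)
	assert mus.shape == (Y.shape[1], 1) and sigmas.shape == (Y.shape[1], 1)
	return A, mus, sigmas, decay_coef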


def testInputData(Y):
	if np.abs(Y - np.array(Y, dtype='int32')).sum() < 1e-6:
		raise Exception('Your input matrix is entirely integers. It is possible but unlikely that this is correct: ZIFA takes as input LOG read counts, not read counts.')

	Y_is_zero = np.abs(Y) < 1e-6
	if (Y_is_zero).sum() == 0:
		raise Exception('Your input matrix contains no zeros. This is possible but highly unlikely in scRNA-seq data. ZIFA takes as input log read counts.')

	if (Y < 0).sum() > 0:
		raise Exception('Your input matrix contains negative values. ZIFA takes as input log read counts and should not contain negative values.')

	zero_fracs = Y_is_zero.mean(axis=0)
	column_is_all_zero = zero_fracs == 1.
	if column_is_all_zero.sum() > 0:
		raise Exception("Your Y matrix has columns which are entirely zero; please filter out these columns and rerun the algorithm.")


def fitModel(Y, K, singleSigma=False):
	"""
	fits the model to data.
	Input:
	Y: data matrix, n_samples x n_genes
	K: number of latent components
	singleSigma: if True, fit only a single variance parameter (zero-inflated PPCA) rather than a different one for every gene (zero-inflated factor analysis).
	Returns:
	EZ: the estimated positions in the latent space, n_samples x K
	params: a dictionary of model parameters. Throughout, we refer to lambda as "decay_coef".
	"""
	Y = deepcopy(Y)
	N, D = Y.shape

	if D > 2000:
		print('Warning: this dataset has a large number of genes. If ZIFA takes too long to run, try using block_ZIFA.py instead')

	testInputData(Y)

	print('Running zero-inflated factor analysis with N = %i, D = %i, K = %i' % (N, D, K))

	# Initialize the parameters
	np.random.seed(23)
	A, mus, sigmas, decay_coef = initializeParams(Y, K, singleSigma=singleSigma)
	checkNoNans([A, mus, sigmas, decay_coef])

	max_iter = 100
	param_change_thresh = 1e-2
	n_iter = 0

	while n_iter < max_iter:

		EZ, EZZT, EX, EXZ, EX2 = Estep(Y, A, mus, sigmas, decay_coef)
		new_A, new_mus, new_sigmas, new_decay_coef = Mstep(Y, EZ, EZZT, EX, EXZ, EX2, A, mus, sigmas, decay_coef, singleSigma=singleSigma)
		checkNoNans([EZ, EZZT, EX, EXZ, EX2, new_A, new_mus, new_sigmas, new_decay_coef])

		paramsNotChanging = True
		max_param_change = 0

		for new, old in [[new_A, A], [new_mus, mus], [new_sigmas, sigmas], [new_decay_coef, decay_coef]]:
			rel_param_change = np.mean(np.abs(new - old)) / np.mean(np.abs(new))

			if rel_param_change > max_param_change:
				max_param_change = rel_param_change

			if rel_param_change > param_change_thresh:
				paramsNotChanging = False
				break

		A = new_A
		mus = new_mus
		sigmas = new_sigmas
		decay_coef = new_decay_coef

		if paramsNotChanging:
			print('Param change below threshold %2.3e after %i iterations' % (param_change_thresh, n_iter))
			break

		n_iter += 1
		if n_iter >= max_iter:
			print('Maximum number of iterations reached; terminating loop')

	EZ, EZZT, EX, EXZ, EX2 = Estep(Y, A, mus, sigmas, decay_coef)
	params = {'A': A, 'mus': mus, 'sigmas': sigmas, 'decay_coef': decay_coef}

	return EZ, params