import os
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt

# Module-wide handle to the TensorFlow flags object; save_image below reads
# FLAGS.fig_dir from it to decide where figures are written.
FLAGS = tf.app.flags.FLAGS


def _tile_image_grid(batch, n_cols=4):
    """Arrange a channels-last batch of images into one grid image.

    Args:
        batch: array whose axis 0 indexes images, axis 1 is height and
            axis 2 is width (any further axes, e.g. channels, are kept).
        n_cols: number of images placed side by side in each grid row.

    Returns:
        Array of shape (n_rows * H, n_cols * W, ...) with
        n_rows = N // n_cols. Images are laid out row-major; trailing
        images that do not fill a complete row are dropped.

    Raises:
        ValueError: if the batch holds fewer than ``n_cols`` images.
    """
    n_rows = batch.shape[0] // n_cols
    if n_rows == 0:
        raise ValueError(
            "Need at least %d images to build a grid row, got %d"
            % (n_cols, batch.shape[0])
        )

    height, width = batch.shape[1:3]
    trailing = tuple(batch.shape[3:])

    # (N, H, W, ...) -> (rows, cols, H, W, ...) -> (rows, H, cols, W, ...)
    # so that collapsing adjacent axes puts images next to each other.
    grid = batch[: n_rows * n_cols].reshape(
        (n_rows, n_cols, height, width) + trailing
    )
    grid = grid.swapaxes(1, 2)
    return grid.reshape((n_rows * height, n_cols * width) + trailing)


def save_image(data, data_format, e, suffix=None):
    """Save a picture showing the current progress of the model.

    The first 8 generated samples followed by the first 8 real samples are
    tiled 4 per row into one image (generated samples fill the top rows,
    real samples the bottom rows) and written as a PNG into FLAGS.fig_dir.

    Args:
        data: pair ``(X_G, X_real)`` of generated and real image batches.
        data_format: ``"NHWC"`` (channels last) or ``"NCHW"`` (channels first).
        e: value interpolated into the file name (callers presumably pass
            the epoch index -- confirm against caller).
        suffix: optional tag; when given the file is named
            ``current_batch_<suffix>_<e>.png`` instead of
            ``current_batch_<e>.png``.

    Raises:
        ValueError: if ``data_format`` is not one of the two supported
            layouts, or if fewer than 4 images are available in total.
    """
    X_G, X_real = data

    batch = np.concatenate((X_G[:8], X_real[:8]), axis=0)

    if data_format == "NCHW":
        # Move channels last once so both layouts share the same tiling path.
        batch = batch.transpose(0, 2, 3, 1)
    elif data_format != "NHWC":
        # Previously an unknown layout silently passed the raw batch on to
        # imshow, which failed with an unrelated shape error.
        raise ValueError(
            "Unsupported data_format %r (expected 'NHWC' or 'NCHW')"
            % (data_format,)
        )

    grid = _tile_image_grid(batch, n_cols=4)

    if suffix is None:
        fname = "current_batch_%s.png" % (e,)
    else:
        fname = "current_batch_%s_%s.png" % (suffix, e)

    try:
        if grid.shape[-1] == 1:
            # Single-channel images: drop the channel axis and use a gray map.
            plt.imshow(grid[:, :, 0], cmap="gray")
        else:
            plt.imshow(grid)
        plt.axis("off")
        plt.savefig(os.path.join(FLAGS.fig_dir, fname))
    finally:
        # Always release the figure, even if drawing or saving failed, so a
        # failed call cannot leak a figure or leave stale content for the next.
        plt.clf()
        plt.close()


def get_stacked_tensor(X1, X2):
    """Build a single-image tensor tiling samples from X1 and X2.

    The first 16 entries of ``X1`` and of ``X2`` are concatenated along
    axis 0, then entries 0..31 of the result are laid out as 8 rows of 4.
    Each entry is joined along its axis 2 within a row, rows are joined
    along axis 1, and the result is transposed with ``(1, 2, 0)`` before a
    leading axis of size 1 is added.

    NOTE(review): the axis choices are consistent with channel-first
    (C, H, W) samples -- confirm against callers. Indices up to 31 are read,
    so the two inputs together must supply at least 32 samples.

    Args:
        X1: first batch tensor.
        X2: second batch tensor.

    Returns:
        The tiled tensor with an extra leading dimension of size 1.
    """
    per_row = 4
    n_rows = 8

    stacked = tf.concat((X1[:16], X2[:16]), axis=0)

    # One concat per grid row, each joining `per_row` consecutive samples.
    row_tensors = [
        tf.concat(
            [stacked[idx] for idx in range(per_row * r, per_row * (r + 1))],
            axis=2,
        )
        for r in range(n_rows)
    ]

    grid = tf.concat(row_tensors, axis=1)
    grid = tf.transpose(grid, (1, 2, 0))
    return tf.expand_dims(grid, 0)