from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np



def plotSubFigure(X, Y, Z, subfig, type_):
    fig = plt.gcf()
    ax = fig.add_subplot(1, 3, subfig, projection='3d')
    #ax = fig.gca(projection='3d')
    if type_ == "colormap":
        ax.plot_surface(X, Y, Z, cmap=cm.viridis, rstride=1, cstride=1,
                        shade=True, linewidth=0, antialiased=False)
    else:
        ax.plot_surface(X, Y, Z, color=[0.7, 0.7, 0.7], rstride=1, cstride=1,
                        shade=True, linewidth=0, antialiased=False)

    ax.set_aspect("equal")

    max_range = np.array([X.max()-X.min(), Y.max()-Y.min(), Z.max()-Z.min()]).max() / 2.0 * 0.6
    mid_x = (X.max()+X.min()) * 0.5
    mid_y = (Y.max()+Y.min()) * 0.5
    mid_z = (Z.max()+Z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    az, el = 90, 90
    if type_ == "top":
        az = 130
    elif type_ == "side":
        az, el = 40, 0

    ax.view_init(az, el)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    plt.grid(False)
    plt.axis('off')


def plotDepth(Z):
    x = np.linspace(0, Z.shape[0]-1, Z.shape[0])
    y = np.linspace(0, Z.shape[1]-1, Z.shape[1])
    X, Y = np.meshgrid(x, y)

    fig = plt.figure(figsize=(12, 6))
    plotSubFigure(X, Y, Z, 1, "colormap")
    plotSubFigure(X, Y, Z, 2, "top")
    plotSubFigure(X, Y, Z, 3, "side")

    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)