from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

import torch
from random import gauss, uniform


def make_rand_vector(dims):
    """Return a list of ``dims`` independent draws from N(0, 0.12**2).

    Each entry is a separate scalar draw from NumPy's global random state,
    so the result is reproducible under ``np.random.seed``.
    """
    return [np.random.normal(loc=0, scale=0.12) for _ in range(dims)]


def rand_vec(dim, z):
    """Draw a random unit-norm ``dim``-vector whose last coordinate is ``z``.

    The first ``dim - 1`` coordinates are drawn from U(0, 1) and rescaled so
    that their squared norm equals ``1 - z**2``. Together with ``z`` the
    vector then has Euclidean norm 1. The squared norm of the result is
    printed as a sanity check (it should be ~1.0).

    Args:
        dim: Total number of coordinates (including ``z``).
        z: Fixed value of the last coordinate; must satisfy ``|z| <= 1``.

    Returns:
        The vector as a list of floats, with ``z`` as the last element.

    Raises:
        ValueError: If ``|z| > 1``, since no unit vector has such a coordinate.
    """
    if abs(z) > 1:
        raise ValueError("z must lie in [-1, 1] to build a unit vector")
    head = [uniform(0, 1) for _ in range(dim - 1)]
    head_norm = sum(x ** 2 for x in head) ** 0.5
    # Norm the leading coordinates must have so that head**2 + z**2 == 1.
    target = max(1 - z ** 2, 0.0) ** 0.5
    if head_norm > 0:
        head = [x * target / head_norm for x in head]
    vec = head + [z]
    print(sum(x ** 2 for x in vec))
    return vec


def draw_ball(z=np.linspace(-1, 1, 50)):
    """Sample points on circles of latitude of the unit sphere.

    For every height ``h`` in ``z`` the circle of radius ``sqrt(1 - h**2)``
    is sampled at 20 angles evenly spaced from -180 to 180 degrees
    (inclusive, so each circle's first and last point coincide).
    Each point is returned as ``[h, r*cos(theta), r*sin(theta)]``.
    Heights outside [-1, 1] give a negative radicand (NaN from ``np.sqrt``).
    """
    angles_deg = np.linspace(-180, 180, 20)
    points = []
    for height in z:
        radius = np.sqrt(1 - height ** 2)
        for deg in angles_deg:
            rad = np.pi * deg / 180.
            points.append([height, radius * np.cos(rad), radius * np.sin(rad)])
    return points


def drawSphere(xCenter, yCenter, zCenter, r):
    """Return wireframe grids ``(x, y, z)`` for a sphere of radius ``r``.

    The unit sphere is parameterised by azimuth ``u`` in [0, 2*pi]
    (20 samples) and polar angle ``v`` in [0, pi] (10 samples). It is then
    scaled by ``r`` and shifted to ``(xCenter, yCenter, zCenter)``. Each
    returned array has shape (20, 10), suitable for ``plot_wireframe``.
    """
    u, v = np.mgrid[0:2 * np.pi:20j, 0:np.pi:10j]
    sin_v = np.sin(v)
    unit_coords = (np.cos(u) * sin_v, np.sin(u) * sin_v, np.cos(v))
    centre = (xCenter, yCenter, zCenter)
    return tuple(r * coord + offset for coord, offset in zip(unit_coords, centre))


# Sample N_CLUSTERS Gaussian clusters. Each cluster centre comes from
# make_rand_vector(3); each cluster then gets POINTS_PER_CLUSTER points drawn
# from N(centre, CLUSTER_STD**2) independently per coordinate.
# N_CLUSTERS should not exceed the number of plotting colours used below.
N_CLUSTERS = 5
POINTS_PER_CLUSTER = 20
CLUSTER_STD = 0.1

bag = []
# NOTE(review): GVar appears unused in this script — confirm before removing.
from NVLL.util.util import GVar

for _ in range(N_CLUSTERS):
    center = np.asarray(make_rand_vector(3))
    bag.append([np.random.normal(center, CLUSTER_STD)
                for _ in range(POINTS_PER_CLUSTER)])

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Dotted wireframe of the unit sphere centred at the origin, as a backdrop.
xs, ys, zs = drawSphere(0, 0, 0, 1)
ax.plot_wireframe(
    xs, ys, zs,
    color="black",
    linestyle=":",
    lw=0.75,
)
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two endpoints are given in 3D data coordinates.

    Add it to an ``Axes3D`` with ``ax.add_artist``. The endpoints are
    projected to 2D with the parent axes' projection matrix whenever the
    artist is projected or drawn.

    Args:
        xs, ys, zs: Length-2 sequences holding the (start, end) coordinate
            along each axis.
        *args, **kwargs: Forwarded to ``FancyArrowPatch`` (e.g.
            ``arrowstyle``, ``color``, ``mutation_scale``).
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2D positions; the real ones are set on every projection.
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self):
        """Project the 3D endpoints to 2D, set them, and return projected depths."""
        xs3d, ys3d, zs3d = self._verts3d
        # Read the matrix from the parent axes: recent Matplotlib releases no
        # longer attach it to the renderer as ``renderer.M``.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Update the 2D positions and return the minimum depth.

        Recent ``Axes3D.draw`` implementations call this on every patch to
        order them by depth; without it drawing raises AttributeError.
        """
        return np.min(self._project())

    def draw(self, renderer):
        self._project()
        FancyArrowPatch.draw(self, renderer)


# Black coordinate-axis arrows spanning [-1.2, 1.2] along z, x and y.
axis_arrow_style = dict(mutation_scale=7, lw=1, arrowstyle="-|>", color="black")
a = Arrow3D([0, 0], [0, 0], [-1.2, 1.2], **axis_arrow_style)
b = Arrow3D([-1.2, 1.2], [0, 0], [0, 0], **axis_arrow_style)
c = Arrow3D([0, 0], [-1.2, 1.2], [0, 0], **axis_arrow_style)
for axis_arrow in (a, b, c):
    ax.add_artist(axis_arrow)

# Plot every cluster in ``bag`` in its own colour, plus an arrow from the
# origin to that cluster's sample mean.
bank_color = ['b', 'r', 'g', 'c', 'm']
for idx, group in enumerate(bag):
    if len(group) == 0:
        # No points means no mean to point at.
        continue
    # Cycle the palette so extra clusters reuse colours instead of raising IndexError.
    color = bank_color[idx % len(bank_color)]
    points = np.asarray(group)  # one row per point: (x, y, z)
    # One scatter call per cluster; depthshade=False turns off depth-based
    # alpha fading, so every point is drawn fully opaque.
    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
               c=color, marker='.', depthshade=False)
    mean_x, mean_y, mean_z = points.mean(axis=0)
    ax.add_artist(Arrow3D([0, mean_x], [0, mean_y], [0, mean_z],
                          mutation_scale=10, lw=1.5,
                          arrowstyle="-|>", color=color))

ax.set_axis_off()

# Axis cosmetics: ticks every 0.5 in [-1, 1], fully transparent panes and
# invisible grid lines on all three 3D axes.
# NOTE(review): with the axis turned off above these are presumably only
# visible after ax.set_axis_on() — confirm whether they are still needed.
tick_positions = np.arange(-1, 1.5, 0.5)
transparent_white = (1.0, 1.0, 1.0, 0.0)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticks(tick_positions)
    axis.set_pane_color(transparent_white)
    # No public setter for 3D grid-line colour; this writes the private
    # _axinfo dict and may break across Matplotlib versions.
    axis._axinfo["grid"]['color'] = (1, 1, 1, 0)

# Save before show(): the figure may be torn down once its window closes.
fig.savefig('gauss.pdf', transparent=True)
plt.show()