"""
Author(s): Matthew Loper

See LICENCE.txt for licensing and contact information.
"""

import numpy as np
from copy import deepcopy
import scipy.sparse as sp
from .cvwrap import cv2

    from scipy.stats import nanmean as nanmean_impl
    from numpy import nanmean as nanmean_impl

from chumpy.utils import row, col
from .contexts._constants import *

def nanmean(a, axis):
    """Return the mean of `a` along `axis` as computed by `nanmean_impl`.

    This is a thin wrapper so the rest of the module does not depend on
    which library supplied `nanmean_impl`.
    """
    # NOTE: do not add a nan_to_num call here unless you have verified that
    # occlusion_test.py still works afterwards!
    return nanmean_impl(a, axis=axis)

def nangradients(arr):
    """Return ``(gy, gx)``: gradients of a 3D array along axes 0 and 1,
    tolerating NaN entries.

    `arr` is indexed as ``arr[row, column, channel]``. Border rows/columns
    of the result come straight from ``np.gradient``. Interior entries are
    the NaN-ignoring mean of the two adjacent one-sided differences, so a
    NaN neighbour on one side yields the difference from the other side
    instead of NaN. For NaN-free input the interior equals the usual
    central difference.

    Both returned arrays have the same (3D) shape as `arr`.
    """
    # One-sided differences between neighbouring rows / columns, with a
    # trailing axis added so pairs of them can be stacked and averaged.
    dy = np.expand_dims(arr[:-1,:,:] - arr[1:,:,:], axis=3)
    dx = np.expand_dims(arr[:,:-1,:] - arr[:, 1:, :], axis=3)

    # For every interior position, pair the difference towards the next
    # neighbour with the one towards the previous neighbour and average
    # them while ignoring NaNs.
    dy = np.concatenate((dy[1:,:,:], dy[:-1,:,:]), axis=3)
    dy = nanmean(dy, axis=3)
    dx = np.concatenate((dx[:,1:,:], dx[:,:-1,:]), axis=3)
    dx = nanmean(dx, axis=3)

    if arr.shape[2] > 1:
        # Multi-channel: np.gradient also returns a gradient along the
        # channel axis, which is not needed.
        gy, gx, _ = np.gradient(arr)
    else:
        # Single channel: differentiate the 2D image and restore the
        # trailing channel axis.
        gy, gx = np.gradient(arr.squeeze())
        gy = np.atleast_3d(gy)
        gx = np.atleast_3d(gx)

    # dy/dx were computed as (previous - next), hence the sign flip.
    gy[1:-1,:,:] = -dy
    gx[:,1:-1,:] = -dx

    return gy, gx

def dImage_wrt_2dVerts_bnd(observed, visible, visibility, barycentric, image_width, image_height, num_verts, f, bnd_bool):
    """Construct a sparse jacobian that relates 2D projected vertex positions
    (in the columns) to pixel values (in the rows), treating boundary pixels
    separately from interior pixels.

    This is done in two steps: first the sparsity structure (row indices IS
    and column indices JS) is built, then the derivative values.

    Parameters, as used by the code below:
      observed     -- image whose gradients are taken (2D, or 3D with channels last).
      visible      -- flat pixel indices that are covered by a primitive.
      visibility   -- per-pixel primitive index image; indexes rows of `f`.
      barycentric  -- image indexed as [row, column, 0..2] giving the
                      barycentric weight of each primitive vertex at a pixel.
      image_width, image_height, num_verts -- determine the result's shape.
      f            -- primitives as rows of vertex indices (3 columns for
                      triangles; any other width is handled as 2-vertex edges).
      bnd_bool     -- boundary mask; reshaped to observed's height x width.

    Returns a scipy.sparse.csc_matrix of shape
    (image_width*image_height*n_channels, num_verts*2). Column 2*i is the x
    coordinate of vertex i and column 2*i+1 its y coordinate.
    """

    n_channels = np.atleast_3d(observed).shape[2]
    shape = visibility.shape

    # Step 1: get the structure ready, ie the IS and the JS.
    # Every visible pixel depends on the x and y of each vertex of the
    # primitive it shows, i.e. on 2*f.shape[1] columns.
    IS = np.tile(col(visible), (1, 2*f.shape[1])).ravel()
    JS = col(f[visibility.ravel()[visible]].ravel())
    JS = np.hstack((JS*2, JS*2+1)).ravel()

    # Column / row coordinates of the visible pixels.
    pxs = np.asarray(visible % shape[1], np.int32)
    pys = np.asarray(np.floor(np.floor(visible) / shape[1]), np.int32)

    if n_channels > 1:
        IS = np.concatenate([IS*n_channels+i for i in range(n_channels)])
        JS = np.concatenate([JS for i in range(n_channels)])

    # Step 2: get the data ready, ie the actual values of the derivatives.
    # Multiplier image: 1 at non-boundary pixels, NaN at boundary pixels, so
    # that boundary pixels are excluded from the interior gradients.
    bndf = bnd_bool.astype(np.float64)
    bnd_nan = bndf.reshape((observed.shape[0], observed.shape[1], -1)).copy()
    bnd_nan[bnd_nan > 0] = np.nan
    bnd_nan += 1
    obs_nonbnd = np.atleast_3d(observed) * bnd_nan

    ydiffnb, xdiffnb = nangradients(obs_nonbnd)

    observed = np.atleast_3d(observed)

    if observed.shape[2] > 1:
        ydiffbnd, xdiffbnd, _ = np.gradient(observed)
    else:
        ydiffbnd, xdiffbnd = np.gradient(observed.squeeze())
        ydiffbnd = np.atleast_3d(ydiffbnd)
        xdiffbnd = np.atleast_3d(xdiffbnd)

    # This corrects for a bias imposed by boundary differences being spread
    # over two pixels (by np.gradient or similar) but only counted once
    # (since OpenGL's line drawing spans 1 pixel)
    xdiffbnd *= 2.0
    ydiffbnd *= 2.0

    xdiffnb = -xdiffnb
    ydiffnb = -ydiffnb
    xdiffbnd = -xdiffbnd
    ydiffbnd = -ydiffbnd

    # Wherever the NaN-aware gradient is undefined, fall back to the plain
    # (boundary-corrected) gradient. Boolean indexing is used so the
    # assignment also works when the arrays are not contiguous.
    nan_x = np.isnan(xdiffnb)
    xdiffnb[nan_x] = xdiffbnd[nan_x]

    nan_y = np.isnan(ydiffnb)
    ydiffnb[nan_y] = ydiffbnd[nan_y]

    xdiff = xdiffnb
    ydiff = ydiffnb

    datas = []

    # The data is weighted according to barycentric coordinates
    bc0 = col(barycentric[pys, pxs, 0])
    bc1 = col(barycentric[pys, pxs, 1])
    bc2 = col(barycentric[pys, pxs, 2])
    for k in range(n_channels):
        dxs = col(xdiff[pys, pxs, k])
        dys = col(ydiff[pys, pxs, k])
        # The per-pixel order must match JS: (x, y) of vertex 0, then
        # (x, y) of vertex 1, and so on.
        if f.shape[1] == 3:
            datas.append(np.hstack((dxs*bc0, dys*bc0, dxs*bc1, dys*bc1, dxs*bc2, dys*bc2)).ravel())
        else:
            datas.append(np.hstack((dxs*bc0, dys*bc0, dxs*bc1, dys*bc1)).ravel())

    data = np.concatenate(datas)

    ij = np.vstack((IS.ravel(), JS.ravel()))
    result = sp.csc_matrix((data, ij), shape=(image_width*image_height*n_channels, num_verts*2))

    return result

def dImage_wrt_2dVerts(observed, visible, visibility, barycentric, image_width, image_height, num_verts, f):
    """Construct a sparse jacobian that relates 2D projected vertex positions
    (in the columns) to pixel values (in the rows).

    This is done in two steps: first the sparsity structure (row indices IS
    and column indices JS) is built, then the derivative values, which are
    normalized Sobel image gradients weighted by barycentric coordinates.

    Parameters, as used by the code below:
      observed     -- image whose gradients are taken (2D, or 3D with channels last).
      visible      -- flat pixel indices that are covered by a primitive.
      visibility   -- per-pixel primitive index image; indexes rows of `f`.
      barycentric  -- image indexed as [row, column, 0..2] giving the
                      barycentric weight of each primitive vertex at a pixel.
      image_width, image_height, num_verts -- determine the result's shape.
      f            -- primitives as rows of vertex indices (3 columns for
                      triangles; any other width is handled as 2-vertex edges).

    Returns a scipy.sparse.csc_matrix of shape
    (image_width*image_height*n_channels, num_verts*2). Column 2*i is the x
    coordinate of vertex i and column 2*i+1 its y coordinate.
    """

    n_channels = np.atleast_3d(observed).shape[2]
    shape = visibility.shape

    # Step 1: get the structure ready, ie the IS and the JS.
    # Every visible pixel depends on the x and y of each vertex of the
    # primitive it shows, i.e. on 2*f.shape[1] columns.
    IS = np.tile(col(visible), (1, 2*f.shape[1])).ravel()
    JS = col(f[visibility.ravel()[visible]].ravel())
    JS = np.hstack((JS*2, JS*2+1)).ravel()

    # Column / row coordinates of the visible pixels.
    pxs = np.asarray(visible % shape[1], np.int32)
    pys = np.asarray(np.floor(np.floor(visible) / shape[1]), np.int32)

    if n_channels > 1:
        IS = np.concatenate([IS*n_channels+i for i in range(n_channels)])
        JS = np.concatenate([JS for i in range(n_channels)])

    # Step 2: get the data ready, ie the actual values of the derivatives.
    # Same Sobel aperture as dImage_wrt_2dVerts_bnd uses.
    ksize = 1
    # Response of the Sobel filter to a ramp that rises by 1 per column,
    # sampled at an interior pixel; dividing by it expresses the gradients
    # below in intensity units per pixel.
    sobel_normalizer = cv2.Sobel(np.asarray(np.tile(row(np.arange(10)), (10, 1)), np.float64), cv2.CV_64F, dx=1, dy=0, ksize=ksize)[5,5]
    xdiff = -cv2.Sobel(observed, cv2.CV_64F, dx=1, dy=0, ksize=ksize) / sobel_normalizer
    ydiff = -cv2.Sobel(observed, cv2.CV_64F, dx=0, dy=1, ksize=ksize) / sobel_normalizer

    xdiff = np.atleast_3d(xdiff)
    ydiff = np.atleast_3d(ydiff)

    datas = []

    # The data is weighted according to barycentric coordinates
    bc0 = col(barycentric[pys, pxs, 0])
    bc1 = col(barycentric[pys, pxs, 1])
    bc2 = col(barycentric[pys, pxs, 2])
    for k in range(n_channels):
        dxs = col(xdiff[pys, pxs, k])
        dys = col(ydiff[pys, pxs, k])
        # The per-pixel order must match JS: (x, y) of vertex 0, then
        # (x, y) of vertex 1, and so on.
        if f.shape[1] == 3:
            datas.append(np.hstack((dxs*bc0, dys*bc0, dxs*bc1, dys*bc1, dxs*bc2, dys*bc2)).ravel())
        else:
            datas.append(np.hstack((dxs*bc0, dys*bc0, dxs*bc1, dys*bc1)).ravel())

    data = np.concatenate(datas)

    ij = np.vstack((IS.ravel(), JS.ravel()))
    result = sp.csc_matrix((data, ij), shape=(image_width*image_height*n_channels, num_verts*2))

    return result

def flow_to(self, v_next, cam_next):
    """Return a flow image: per visible pixel, the change in camera
    projection of the surface point seen there when the vertices move from
    `self.v` to `v_next` (and the camera from `self.camera` to `cam_next`).

    v_next   -- vertex positions for the next frame.
    cam_next -- camera for the next frame; None means reuse `self.camera`.

    The result has shape (frustum height, frustum width, n_channels) where
    n_channels is `self.camera.shape[1]`; pixels showing no face are zero.
    `self.camera` and its `v` attribute are temporarily replaced and are
    restored before returning, also when an error occurs.
    """
    from chumpy.ch import MatVecMult

    # NOTE(review): the value is unused; the read of `self.r` is kept in
    # case it refreshes render state the attributes below rely on -- confirm
    # before removing.
    _ = self.r
    visibility = self.visibility_image
    barycentric = self.barycentric_image

    # 4294967295 marks pixels that show no face.
    visible = np.nonzero(visibility.ravel() != 4294967295)[0]

    # map 3d to 3d: each visible pixel is a barycentric blend of the three
    # vertices of the face it shows.
    JS = col(self.f[visibility.ravel()[visible]]).ravel()
    # Integer division keeps the row indices integral under Python 3.
    IS = np.tile(col(np.arange(JS.size // 3)), (1, 3)).ravel()
    data = barycentric.reshape((-1,3))[visible].ravel()

    # replicate to xyz
    IS = np.concatenate((IS*3, IS*3+1, IS*3+2))
    JS = np.concatenate((JS*3, JS*3+1, JS*3+2))
    data = np.concatenate((data, data, data))

    verts_to_visible = sp.csc_matrix((data, (IS, JS)), shape=(np.max(IS)+1, self.v.r.size))

    cam_old = self.camera
    v_old = cam_old.v

    if cam_next is None:
        cam_next = cam_old

    try:
        # Projection of the visible surface points with the current state...
        self.camera.v = MatVecMult(verts_to_visible, self.v.r)
        r1 = self.camera.r.copy()

        # ...and with the next vertices / camera.
        self.camera = cam_next
        self.camera.v = MatVecMult(verts_to_visible, v_next)
        r2 = self.camera.r.copy()

        n_channels = self.camera.shape[1]
    finally:
        self.camera = cam_old
        self.camera.v = v_old

    flow = r2 - r1
    flow_im = np.zeros((self.frustum['height'], self.frustum['width'], n_channels)).reshape((-1,n_channels))

    flow_im[visible] = flow
    flow_im = flow_im.reshape((self.frustum['height'], self.frustum['width'], n_channels))

    return flow_im

def dr_wrt_bgcolor(visibility, frustum, num_channels):
    """Sparse jacobian of the rendered image with respect to the background
    color.

    Every pixel whose `visibility` entry equals 4294967295 contributes, for
    each channel c, a 1 at row (pixel*num_channels + c), column c. The
    result is a csc_matrix of shape
    (frustum['width']*frustum['height']*num_channels, num_channels).
    """
    background_id = 4294967295
    bg_pixels = np.nonzero(visibility.ravel() == background_id)[0]
    n_bg = len(bg_pixels)
    channels = range(num_channels)

    # One (row, column, value) triple per background pixel and channel.
    rows = np.concatenate([bg_pixels * num_channels + c for c in channels])
    cols = np.concatenate([np.zeros(n_bg) * num_channels + c for c in channels])
    values = np.concatenate([np.ones(n_bg) for _ in channels])

    n_rows = frustum['width'] * frustum['height'] * num_channels
    coords = np.vstack((rows.ravel(), cols.ravel()))
    return sp.csc_matrix((values, coords), shape=(n_rows, num_channels))

def dr_wrt_vc(visible, visibility, f, barycentric, frustum, vc_size, num_channels):
    """Sparse jacobian of the rendered image with respect to vertex colors.

    For each pixel index in `visible`, the three vertices of the face named
    by `visibility` at that pixel each receive that pixel's barycentric
    weight. Rows are (pixel*num_channels + channel); columns are
    (vertex*num_channels + channel). The result is a csc_matrix of shape
    (frustum['width']*frustum['height']*num_channels, vc_size).
    """
    channels = range(num_channels)

    # Vertex ids of the face seen at every visible pixel, as one column.
    face_verts = col(f[visibility.ravel()[visible]].ravel())
    # Each pixel relies on three verts, so its index appears three times.
    pixel_ids = np.tile(col(visible), (1, 3)).ravel()
    weights = np.asarray(barycentric.reshape((-1, 3))[visible, :], order='C').ravel()

    # Replicate the structure once per channel.
    rows = np.concatenate([pixel_ids * num_channels + c for c in channels])
    cols = np.concatenate([face_verts * num_channels + c for c in channels])
    values = np.concatenate([weights for _ in channels])

    n_rows = frustum['width'] * frustum['height'] * num_channels
    coords = np.vstack((rows.ravel(), cols.ravel()))
    return sp.csc_matrix((values, coords), shape=(n_rows, vc_size))

def draw_visibility_image(gl, v, f, boundarybool_image=None):
    """Render an image holding, per pixel, the index of the face drawn there
    (4294967295 where no face was drawn).

    Without `boundarybool_image` the filled rendering is returned as is.
    With it, the faces that appear in the filled rendering are drawn a
    second time in line mode, and the two renderings are blended: the line
    rendering is weighted by `boundarybool_image` and the filled rendering
    by (1 - boundarybool_image).
    """
    no_face = 4294967295
    v = np.asarray(v)

    filled = draw_visibility_image_internal(gl, v, f)
    if boundarybool_image is None:
        return filled

    flat = filled.ravel()
    faces_to_draw = np.unique(flat[flat != no_face])
    if len(faces_to_draw) == 0:
        # Nothing was drawn at all: return an all-"no face" image.
        return np.ones((gl.height, gl.width)).astype(np.uint32) * no_face

    # Second pass: only the faces seen in the first pass, as outlines.
    gl.PolygonMode(GL_FRONT_AND_BACK, GL_LINE)
    outlined = draw_visibility_image_internal(gl, v, f[faces_to_draw])
    gl.PolygonMode(GL_FRONT_AND_BACK, GL_FILL)

    # The second pass numbers faces by position in `faces_to_draw`; map
    # those back to indices into the full face array.
    outlined = outlined.ravel()
    hit = outlined != no_face
    outlined[hit] = faces_to_draw[outlined[hit]]
    outlined = outlined.reshape(filled.shape[:2])

    bbi = boundarybool_image
    return outlined * bbi + filled * (1 - bbi)

def draw_visibility_image_internal(gl, v, f):
    """Draw every face in a color that encodes (face index + 1) and decode
    the image read back from `gl` into per-pixel face indices.

    Assumes camera is set up correctly in gl context.
    """
    # Spread the low three bytes of (face index + 1) over the three color
    # components, one byte each.
    face_ids = col(np.arange(1, len(f)+1))
    fc = np.hstack((face_ids & 255, (face_ids >> 8) & 255, (face_ids >> 16) & 255))
    fc = np.asarray(fc, dtype=np.uint8)

    draw_colored_primitives(gl, v, f, fc)

    # Reassemble the id from the three components. Subtracting 1 undoes the
    # +1 offset, which in uint32 arithmetic turns an all-zero pixel into
    # 4294967295.
    pixels = np.asarray(gl.getImage(), np.uint32)
    return pixels[:,:,0] + pixels[:,:,1]*256 + pixels[:,:,2]*256*256 - 1

# this assumes that fc is either "by faces" or "verts by face", not "by verts"
def draw_colored_primitives(gl, v, f, fc=None):
    """Draw the primitives in `f` (rows of vertex indices into `v`) through
    `gl`, optionally colored by `fc`.

    `fc` may hold one color row per primitive, in which case it is repeated
    for every vertex of the primitive, or one color row per primitive
    vertex. Its elements must be np.float64 or np.uint8.

    Primitives with two vertices are drawn as GL_LINES (and then a second
    time with the vertex order of every edge reversed); anything else is
    drawn as GL_TRIANGLES.

    Raises Exception if `fc` has an incompatible size or element type.
    """
    # NOTE(review): nothing in this function hands verts_by_face,
    # vc_by_face or verts_by_edge to `gl` before DrawElements is called --
    # confirm that the vertex/color array binding calls are not missing.
    verts_by_face = np.asarray(v.reshape((-1,3))[f.ravel()], dtype=np.float64, order='C')

    if fc is not None:
        if fc.size == verts_by_face.size:
            # Already one color per primitive vertex.
            vc_by_face = fc
        else:
            # One color per primitive: repeat it for each of its vertices.
            vc_by_face = np.repeat(fc, f.shape[1], axis=0)

        if vc_by_face.size != verts_by_face.size:
            raise Exception('fc must have either rows=(#rows in faces) or rows=(# elements in faces)')

        if isinstance(fc[0,0], np.float64):
            vc_by_face = np.asarray(vc_by_face, dtype=np.float64, order='C')
        elif isinstance(fc[0,0], np.uint8):
            vc_by_face = np.asarray(vc_by_face, dtype=np.uint8, order='C')
        else:
            raise Exception('Unknown color type for fc')

    if f.shape[1]==2:
        primtype = GL_LINES
    else:
        primtype = GL_TRIANGLES
    gl.DrawElements(primtype, np.arange(f.size, dtype=np.uint32).ravel())

    if primtype == GL_LINES:
        # Draw every edge again with its two endpoints swapped.
        f = np.fliplr(f).copy()
        verts_by_edge = v.reshape((-1,3))[f.ravel()]
        verts_by_edge = np.asarray(verts_by_edge, dtype=np.float64, order='C')
        gl.DrawElements(GL_LINES, np.arange(f.size, dtype=np.uint32).ravel())

def draw_texcoord_image(glf, v, f, vt, ft, boundarybool_image=None):
    """Render texture coordinates as colors and return the first two
    channels of the image read back from `glf`, with the second channel
    flipped (1 - value).

    When `boundarybool_image` is given, the primitives are drawn a second
    time in line mode on top of the filled rendering.
    """
    # One texture coordinate per primitive vertex (rows of `vt` picked by
    # `ft`), padded with a zero third component so it can serve as a color.
    per_corner_uv = vt[ft.ravel()]
    zero_channel = col(per_corner_uv[:,0]*0)
    colors = np.asarray(np.hstack((per_corner_uv, zero_channel)), np.float64, order='C')

    draw_colored_primitives(glf, v, f, colors)

    if boundarybool_image is not None:
        glf.PolygonMode(GL_FRONT_AND_BACK, GL_LINE)
        draw_colored_primitives(glf, v, f, colors)
        glf.PolygonMode(GL_FRONT_AND_BACK, GL_FILL)

    image = np.asarray(deepcopy(glf.getImage()), np.float64, order='C')
    texcoords = image[:,:,:2].copy()
    texcoords[:,:,1] = 1. - texcoords[:,:,1]
    return texcoords

def draw_barycentric_image(gl, v, f, boundarybool_image=None):
    """Render the barycentric-coordinate image for primitives `f`.

    Without `boundarybool_image` the filled rendering is returned. With it,
    a second rendering in line mode is blended in: the line rendering is
    weighted by the mask and the filled rendering by (1 - mask).
    """
    v = np.asarray(v)
    filled = draw_barycentric_image_internal(gl, v, f)

    if boundarybool_image is not None:
        gl.PolygonMode(GL_FRONT_AND_BACK, GL_LINE)
        outlined = draw_barycentric_image_internal(gl, v, f)
        gl.PolygonMode(GL_FRONT_AND_BACK, GL_FILL)

        # Give the mask a channel axis so it broadcasts over the image.
        mask = np.atleast_3d(boundarybool_image)
        return mask * outlined + (1. - mask) * filled

    return filled

def draw_barycentric_image_internal(gl, v, f):
    """Draw primitives `f` (rows of vertex indices into `v`) through `gl`
    and return the image read back from it as a float64 array.

    Primitives with three vertices are drawn as GL_TRIANGLES, anything else
    as GL_LINES.
    """
    # NOTE(review): verts_by_face and vc_by_face are built here but never
    # handed to `gl` in this function -- confirm that the vertex/color array
    # binding calls are not missing.
    verts_by_face = v.reshape((-1,3))[f.ravel()]
    verts_by_face = np.asarray(verts_by_face, dtype=np.float64, order='C')
    # Per primitive vertex, a unit color: vertex k of every primitive gets
    # row k of the 3x3 identity. Integer division keeps the repetition
    # count an int under Python 3, as np.tile requires.
    vc_by_face = np.asarray(np.tile(np.eye(3)[:f.shape[1], :], (verts_by_face.shape[0]//f.shape[1], 1)), order='C')

    gl.DrawElements(GL_TRIANGLES if f.shape[1]==3 else GL_LINES, np.arange(f.size, dtype=np.uint32).ravel())
    result = np.asarray(deepcopy(gl.getImage()), np.float64)

    return result

# May end up using this, maybe not
def get_inbetween_boundaries(self):
    """Render two boundary-id images on grids one pixel narrower / shorter
    than the frustum, with the principal point shifted by half a pixel.

    Returns (bnd_x, bnd_y): bnd_x is rendered with viewport width w-1 and
    the principal point's x reduced by .5, with its last column dropped;
    bnd_y is rendered with viewport height h-1 and the principal point's y
    reduced by .5, with its last row dropped. The viewport and camera of
    `self.glb` are put back to the full frustum before returning, also
    when rendering fails.
    """
    camera = self.camera
    frustum = self.frustum
    w = frustum['width']
    h = frustum['height']
    far = frustum['far']
    near = frustum['near']

    # NOTE(review): the callee name `_setup_camera` for the two explicit
    # parameter lists below is presumed (the call head was missing from this
    # file) -- confirm it against the module's camera setup helpers.
    try:
        self.glb.Viewport(0, 0, w-1, h)
        _setup_camera(self.glb,
                      camera.c.r[0]-.5, camera.c.r[1],
                      camera.f.r[0], camera.f.r[1],
                      w-1, h,
                      near, far,
                      camera.view_matrix, camera.k)
        bnd_x = draw_boundaryid_image(self.glb, self.v.r, self.f, self.vpe, self.fpe, self.camera)[:,:-1]

        self.glb.Viewport(0, 0, w, h-1)
        _setup_camera(self.glb,
                      camera.c.r[0], camera.c.r[1]-.5,
                      camera.f.r[0], camera.f.r[1],
                      w, h-1,
                      near, far,
                      camera.view_matrix, camera.k)
        bnd_y = draw_boundaryid_image(self.glb, self.v.r, self.f, self.vpe, self.fpe, self.camera)[:-1,:]
    finally:
        # Put things back to normal
        self.glb.Viewport(0, 0, w, h)
        setup_camera(self.glb, camera, frustum)
    return bnd_x, bnd_y