"""
Transferring labels from a segmented atlas
=============================================

We use a new multiscale algorithm to solve regularized Optimal Transport
problems on the GPU, with a linear memory footprint.

We then use the resulting smooth assignments to perform label transfer for
atlas-based segmentation of fiber tractograms. The parameters of our method --
*blur* and *reach* -- are meaningful: they define the minimum and maximum
distances at which two fibers are compared with each other, and can be set
according to anatomical knowledge.
"""


##############################################
# Setup
# ---------------------
#
# Standard imports:

import numpy as np
import matplotlib.pyplot as plt
import time
import torch
from geomloss import SamplesLoss
use_cuda = torch.cuda.is_available()
dtype    = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
dtypeint = torch.cuda.LongTensor  if use_cuda else torch.LongTensor

###############################################
# Loading and saving data routines
#

from tract_io import read_vtk, streamlines_resample, save_vtk, save_vtk_labels
from tract_io import save_tract, save_tract_numpy
from tract_io import save_tract_with_labels, save_tracts_labels_separate

##############################################
# Dataset
# ---------------------
#
# Fetch data from the KeOps website:

import os

def fetch_file(name):
    """Downloads an .npy file from the KeOps website, if it is not already on disk."""
    if not os.path.exists(f'data/{name}.npy'):
        import urllib.request
        os.makedirs('data', exist_ok=True)
        print(f"Fetching {name}... ", end="", flush=True)
        urllib.request.urlretrieve(
            f'https://www.kernel-operations.io/data/{name}.npy',
            f'data/{name}.npy')
        print("Done.")

fetch_file("tracto_atlas")
fetch_file("atlas_labels")
fetch_file("tracto1")


##############################################
# Fibers do not have a canonical orientation. Since our ground distance is a simple
# L2-distance between sampled fibers, we augment the dataset with the mirror flip
# of every fiber and perform the OT on this augmented dataset.

def torch_load(X, dtype=dtype):
    return torch.from_numpy(X).type(dtype).contiguous()


def add_flips(X):
    """Concatenates each fiber with its flipped (point-reversed) copy."""
    X_flip = torch.flip( X, (1,) )       # Reverse the order of the sampling points
    X = torch.stack((X, X_flip), dim=1)  # (Nfibers, 2, NPOINTS, 3)
    return X
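
###############################################
# As a quick, purely illustrative sanity check on a dummy fiber (not part of
# the pipeline), the flipped copy simply stores the sampling points of each
# fiber in reverse order:

_toy = torch.arange(6).type(dtype).view(1, 2, 3)  # One fiber with 2 points in R^3
assert (add_flips(_toy)[0, 1] == torch.flip(_toy[0], (0,))).all()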

###############################################################################
# Source atlas 
# ~~~~~~~~~~~~~~~~~~~
#
# Load atlas (segmented, each fiber has a label):

Y_j   = torch_load( np.load("data/tracto_atlas.npy") )
labels_j = torch_load( np.load("data/atlas_labels.npy"), dtype=dtypeint )

###############################################################################
#

M, NPOINTS = Y_j.shape[0], Y_j.shape[1]  # Number of fibers, points per fiber

###############################################################################
#

Y_j = Y_j.view(M, NPOINTS, 3) / np.sqrt(NPOINTS)

###############################################################################
#

Y_j = add_flips(Y_j)  # Shape (M, 2, NPOINTS, 3)
##############################################
# Target subject
# ~~~~~~~~~~~~~~~~~~~~
#
# Load a new subject (unlabelled)
#

X_i = torch_load( np.load("data/tracto1.npy") )
N, NPOINTS_i = X_i.shape[0], X_i.shape[1]   # Number of fibers, points per fiber

if NPOINTS != NPOINTS_i:
    raise ValueError("The atlas and the subject are not sampled with the same number of points: "
                    +f"{NPOINTS} and {NPOINTS_i}, respectively.")

X_i = X_i.view(N, NPOINTS, 3) / np.sqrt(NPOINTS)
X_i = add_flips(X_i)  # Shape (N, 2, NPOINTS, 3)


##############################################
# Feature engineering 
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Add some weight on both ends of our fibers:
#

gamma = 2.
X_i[:,:,0,:] *= gamma ;  X_i[:,:,-1,:] *= gamma
Y_j[:,:,0,:] *= gamma ;  Y_j[:,:,-1,:] *= gamma

###############################################################################
# Optimizing performances
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Contiguous memory accesses are critical for performance on the GPU.
# 

from pykeops.torch.cluster import sort_clusters, cluster_ranges

ranges_j   = cluster_ranges(labels_j)     # Ranges for all clusters
Y_j, labels_j = sort_clusters(Y_j, labels_j)  # Make sure that all clusters are contiguous in memory
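# For each class c, ranges_j[c] = [start_c, end_c) is now the contiguous slice of
# fibers carrying the label c in the sorted arrays Y_j and labels_j.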

C = len(ranges_j)  # Number of classes

if C != labels_j.max() + 1:
    raise ValueError(f"The atlas labels are not contiguous: found {C} clusters "
                     f"but a maximum label of {labels_j.max()}.")

for j, (start_j, end_j) in enumerate(ranges_j):
    if start_j >= end_j:
        raise ValueError(f"The {j}-th cluster of the atlas seems to be empty.")

###############################################################################
# Each fiber is sampled with 20 points in R^3:
# a tractogram is thus a matrix of size n-by-60, where n is the number of fibers.
# The atlas is labelled, which means that each fiber belongs to a cluster.
# This is summarized by the vector labels_j of size n: labels_j[i] is the label of the fiber i.
# Subsample the data if you want to reduce the computational time:

subsample = 20 if True else 1  # Subsampling factor; switch to "if False" to keep the full dataset


##############################################
# 

to_keep = []
for start_j, end_j in ranges_j:
    to_keep += list(range(start_j, end_j, subsample))

Y_j, labels_j = Y_j[to_keep].contiguous(), labels_j[to_keep].contiguous()
ranges_j = cluster_ranges(labels_j)  # Keep the ranges up to date!

##############################################
# 

X_i = X_i[::subsample].contiguous()


##############################################
# 

N, M = len(X_i), len(Y_j)

print("Data loaded.")


##############################################
# Pre-computing cluster prototypes     
# --------------------------------------
#

from pykeops.torch import LazyTensor

def nn_search(x_i, y_j, ranges = None):
    x_i = LazyTensor( x_i[:,None,:] )  # Broadcasted "line" variable
    y_j = LazyTensor( y_j[None,:,:] )  # Broadcasted "column" variable

    D_ij = ((x_i - y_j) ** 2).sum(-1)  # Symbolic matrix of squared distances
    D_ij.ranges = ranges  # Apply our block-sparsity pattern

    return D_ij.argmin(dim=1).view(-1)


################################################################################
# K-Means loop:
#

def KMeans(x_i, c_j, Nits = 10, ranges = None):
    """Simple Lloyd iteration: nearest-centroid assignment + bincount-based centroid update."""
    D = x_i.shape[1]
    for it in range(Nits):
        # Points -> Nearest cluster:
        labs_i = nn_search(x_i, c_j, ranges = ranges)
        # Class cardinals:
        Ncl = torch.bincount(labs_i.view(-1)).type(dtype)
        # Compute the cluster centroids with torch.bincount:
        for d in range(D):  # Unfortunately, vector weights are not supported...
            c_j[:, d] = torch.bincount(labs_i.view(-1), weights=x_i[:, d]) / Ncl

    return c_j, labs_i


##############################################
# On the subject
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The new subject is unlabelled: we perform a simple K-means
# in R^60 to obtain a clustering of its fibers.
#

K = 1000

# Pick K fibers at random:
perm = torch.randperm(N)
random_indices = perm[:K]
C_i = X_i[random_indices]  # (K, 2, NPOINTS, 3) initial centroids

# Reshape our data as "N-by-60" tensors:
C_i_flat = C_i.view(K * 2, NPOINTS * 3)  # Flattened list of centroids
X_i_flat = X_i.view(N * 2, NPOINTS * 3)  # Flattened list of fibers

# Retrieve our new centroids:
C_i_flat, labs_i = KMeans(X_i_flat, C_i_flat)
C_i = C_i_flat.view(K, 2, NPOINTS, 3)
# Standard deviation of our clusters:
std_i = (( X_i_flat - C_i_flat[labs_i.view(-1),:] ) ** 2).sum(dim = 1).mean().sqrt()


############################################################################################
#
# On the atlas
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# To use the multiscale version of the regularized OT solver,
# we need a clustering of both input datasets (atlas and new subject).
# For the atlas, the clustering is given by the segmentation. We use a K-means
# to separate the fibers from their flips within each cluster, so that the fibers
# of a given cluster share a similar orientation.
#
ranges_yi = 2 * ranges_j

ranges_cj = 2 * torch.arange(C).type_as(ranges_j)
ranges_cj = torch.stack((ranges_cj, ranges_cj + 2)).t().contiguous()

slices_i = 1 + torch.arange(C).type_as(ranges_j)
ranges_yi_cj = (ranges_yi, slices_i, ranges_cj, ranges_cj, slices_i, ranges_yi)
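
# With these ranges, the nearest-centroid search below only compares the oriented
# fibers of atlas class c (the rows indexed by ranges_yi[c]) with the two oriented
# centroids of that same class (the columns [2*c, 2*c + 2)).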



################################################################################
# Pick one unoriented fiber (i.e. two oriented fibers) per class:

first_indices = ranges_j[:,0]  # Index of the first fiber of each class

C_j      = Y_j[first_indices.type(dtypeint),:,:,:]             # (C, 2, NPOINTS, 3)
C_j_flat = C_j.view(C * 2, NPOINTS * 3)  # Flattened list of centroids

############################################################################################
#


Y_j_flat = Y_j.view(M * 2, NPOINTS * 3)
C_j_flat, labs_j = KMeans(Y_j_flat, C_j_flat, ranges = ranges_yi_cj)
C_j = C_j_flat.view(C, 2, NPOINTS, 3)

std_j = (( Y_j_flat - C_j_flat[labs_j.view(-1),:] ) ** 2).sum(dim = 1).mean().sqrt()
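
# The typical radii std_i and std_j of these clusters are used below, through
# the cluster_scale parameter of the multiscale Sinkhorn solver.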



########################################################
# Compute the OT plan with the multiscale algorithm    
# ------------------------------------------------------
#
# To use the **multiscale** Sinkhorn algorithm,
# we should simply provide:
#
# - explicit **labels** and **weights** for both input measures,
# - a typical **cluster_scale** which specifies the iteration at which
#   the Sinkhorn loop jumps from a **coarse** to a **fine** representation
#   of the data.
#
blur = 3.
OT_solver =  SamplesLoss("sinkhorn", p=2, blur=blur, reach=20,  
                         scaling=.9, cluster_scale = max(std_i,std_j), 
                         debias=False, potentials=True, verbose=True) 

############################################################################################
# To specify explicit cluster labels, SamplesLoss also requires
# explicit weights. Let's go with the default option, a uniform distribution:

a_i = torch.ones(2 * N).type(dtype) / (2 * N)
b_j = torch.ones(2 * M).type(dtype) / (2 * M)

start = time.time()

# Compute the dual vectors F_i and G_j:
# 6 args -> labels_i, weights_i, locations_i, labels_j, weights_j, locations_j
F_i, G_j = OT_solver(labs_i, a_i, X_i.view(N * 2, NPOINTS * 3), 
                labs_j, b_j, Y_j.view(M * 2, NPOINTS * 3)) 

if use_cuda: torch.cuda.synchronize()
end = time.time()

print("OT computed in  in {:.3f}s.".format(end-start))


##############################################
# Use the OT to perform the transfer of labels
# ----------------------------------------------
# 
# The transport plan pi_{i,j} gives the probability for
# a fiber i of the subject to be assigned to the (labelled) fiber j of the atlas.
# We assign to each fiber i the label with the maximum total probability
# over all of its soft assignments.
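#
# Concretely, with p = 2 the code below evaluates the plan from the dual
# potentials F_i and G_j returned by the solver:
#
#   pi_{i,j} = exp( (F_i + G_j - |x_i - y_j|^2 / 2) / blur^2 ) * a_i * b_j ,
#
# and scores each class l for fiber i through the partial sum of pi_{i,j} over
# the atlas fibers j that belong to class l (up to the constant factor a_i,
# which does not affect the argmax).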

# Return to the original data (unflipped)
X_i = X_i[:,0,:,:].contiguous()  # (N, NPOINTS, 3)
F_i = F_i[::2].contiguous()      # (N,)


################################################
# Compute the transport plan:
#

XX_i = LazyTensor( X_i.view( N,     1, NPOINTS * 3) )
YY_j = LazyTensor( Y_j.view( 1, M * 2, NPOINTS * 3) )
FF_i = LazyTensor( F_i.view( N,     1,      1     ) )
GG_j = LazyTensor( G_j.view( 1, M * 2,      1     ) )

# Cost matrix:
CC_ij = ((XX_i - YY_j) ** 2).sum(-1) / 2  # (N, M * 2, 1) LazyTensor

# Scaled kernel matrix:
KK_ij = (( FF_i + GG_j - CC_ij ) / blur**2 ).exp()  # (N, M * 2, 1) LazyTensor


################################################
# Transfer the labels, bypassing the one-hot vector encoding
# for the sake of efficiency:


def slicing_ranges(start, end):
    """KeOps does not yet support sliced indexing of LazyTensors, so we have to resort to some black magic..."""
    ranges_i    = torch.Tensor([[0, N]]      ).type(dtypeint).int()  # Int32, on the correct device
    slices_i    = torch.Tensor( [1]          ).type(dtypeint).int()
    redranges_j = torch.Tensor([[start, end]]).type(dtypeint).int()
    return (ranges_i, slices_i, redranges_j, redranges_j, slices_i, ranges_i)
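
# The 6-tuple above follows the KeOps block-sparsity format
# (ranges_i, slices_i, redranges_j, ranges_j, slices_j, redranges_i): every row
# i in [0, N) is reduced over the single block of columns [start, end), which
# restricts the kernel sum to one atlas class at a time.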


weights_i = torch.zeros(C + 1, N).type(torch.FloatTensor)  # C classes + outliers

for c in range(C):
    start, end = 2 * ranges_j[c]
    KK_ij.ranges = slicing_ranges(start, end)  # equivalent to "PP_ij[:, start:end]", which is not supported yet...
    weights_i[c] = (KK_ij.sum(dim=1).view(N) / (2 * M)).cpu()

weights_i[C] = 0.2  # If no class reaches a weight of 0.2, the fiber is treated as an outlier

labels_i = weights_i.argmax(dim=0)  # (N,) vector



################################################
# Save our new cluster information as a signal:

# Undo the endpoint re-weighting applied during feature engineering:
X_i[:,0,:] /= gamma ;  X_i[:,-1,:] /= gamma

save_tracts_labels_separate('output/labels_subject', X_i, labels_i, 0, labels_i.max() + 1) #save the data