import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/nn_distance'))
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/sampling'))
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops/grouping'))
import tensorflow as tf
import numpy as np
import tf_util
import tf_nndistance
from tf_sampling import farthest_point_sample, gather_point
from tf_grouping import query_ball_point, group_point
from pointnet_util import pointnet_fp_module, pointnet_sa_module
import config

def placeholder_inputs(config):
    pc_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_POINT, 3))
    color_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_POINT, 3))
    pc_ins_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_GROUP, config.NUM_POINT_INS, 3))
    group_label_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_POINT))
    group_indicator_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_GROUP))
    seg_label_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_POINT))
    bbox_ins_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_GROUP, 6))
    return pc_pl, color_pl, pc_ins_pl, group_label_pl, group_indicator_pl, seg_label_pl, bbox_ins_pl
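
# A minimal usage sketch (illustrative only; the real config object comes from
# config.py and may carry more fields than the ones shown here):
#
#   class _Cfg(object):
#       BATCH_SIZE, NUM_POINT, NUM_GROUP, NUM_POINT_INS = 2, 4096, 64, 512
#   pls = placeholder_inputs(_Cfg())
#   pc_pl, color_pl = pls[0], pls[1]  # feed with sess.run(..., feed_dict={pc_pl: batch_pc, ...})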

def multi_encoding_net(xyz, points, npoint, radius_list, nsample_list, mlp_list, mlp_list2, is_training, bn_decay, scope, bn=True, use_xyz=False, output_shift=False, shift_pred=None, fps_idx=None):
    ''' Encode multi-scale context around seed points.
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius_list: list of float32 -- search radii of the local regions
            nsample_list: list of int32 -- how many points in each local region
            mlp_list: list of list of int32 -- output sizes for the MLP on each point, one list per radius
            mlp_list2: list of int32 -- output sizes for the MLP applied after multi-scale concatenation
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp_list2[-1]) TF tensor
            shift_pred: (batch_size, npoint, 4) TF tensor if output_shift is True, otherwise the input shift_pred passed through
            fps_idx: (batch_size, npoint) TF tensor
    '''
    with tf.variable_scope(scope) as sc:
        if fps_idx is None:
            fps_idx = farthest_point_sample(npoint, xyz) # (batch_size, npoint)
        new_xyz = gather_point(xyz, fps_idx) # (batch_size, npoint, 3)
        new_points_list = []
        for i in range(len(radius_list)):
            radius = radius_list[i]
            nsample = nsample_list[i]
            idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
            grouped_xyz = group_point(xyz, idx)
            grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # [B, nseed, nsmp, 3]
            if shift_pred is not None:
                grouped_xyz -= tf.tile(tf.expand_dims(shift_pred, 2), [1,1,nsample,1])
            if points is not None:
                grouped_points = group_point(points, idx)
                if use_xyz:
                    grouped_points = tf.concat([grouped_points, grouped_xyz], axis=-1)
            else:
                grouped_points = grouped_xyz
            for j,num_out_channel in enumerate(mlp_list[i]):
                grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1, 1],
                                                padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                                scope='conv_prev_%d_%d'%(i,j), bn_decay=bn_decay)
            new_points_list.append(tf.reduce_max(grouped_points, axis=[2]))
        new_points = tf.concat(new_points_list, axis=-1) # (batch_size, npoint, \sum_k{mlp[k][-1]})
        for i,num_out_channel in enumerate(mlp_list2):
            new_points = tf_util.conv1d(new_points, num_out_channel, 1,
                                               padding='VALID', stride=1, bn=bn, is_training=is_training,
                                               scope='conv_post_%d'%i, bn_decay=bn_decay)
        if output_shift:
            shift_pred = tf_util.conv1d(new_points, 4, 1,
                                        padding='VALID', stride=1, scope='conv_shift_pred', activation_fn=None)
        return new_xyz, new_points, shift_pred, fps_idx

def shift_pred_net(xyz, points, npoint_seed, end_points, scope, is_training, bn_decay=None, return_fullfea=False):
    ''' Predict a per-seed-point shift toward the instance center.
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
        Return:
            end_points: dict updated with
                pc_seed: (batch_size, npoint_seed, 3) TF tensor
                shift_pred_seed_4d: (batch_size, npoint_seed, 4) TF tensor
                ind_seed: (batch_size, npoint_seed) TF tensor
    '''
    with tf.variable_scope(scope) as sc:
        ind_seed = farthest_point_sample(npoint_seed, xyz) # (batch_size, npoint_seed)
        pc_seed = gather_point(xyz, ind_seed) # (batch_size, npoint_seed, 3)
        batch_size = xyz.get_shape()[0].value
        num_point = xyz.get_shape()[1].value
        l0_xyz = xyz
        l0_points = None # do not use color for shift prediction

        if return_fullfea:
            new_xyz = tf.concat((pc_seed, xyz), 1)
        else:
            new_xyz = pc_seed

        # Layer 1
        l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=2048, radius=0.2, nsample=32, mlp=[32,32,64], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1')
        l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=512, radius=0.4, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
        l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=128, radius=0.8, nsample=32, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer3')
        l4_xyz, l4_points, l4_indices = pointnet_sa_module(l3_xyz, l3_points, npoint=32, radius=1.6, nsample=32, mlp=[256,256,512], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer4')

        # Feature Propagation layers
        l3_points = pointnet_fp_module(l3_xyz, l4_xyz, l3_points, l4_points, [256,256], is_training, bn_decay, scope='fa_layer1')
        l2_points = pointnet_fp_module(l2_xyz, l3_xyz, l2_points, l3_points, [256,256], is_training, bn_decay, scope='fa_layer2')
        l1_points = pointnet_fp_module(l1_xyz, l2_xyz, l1_points, l2_points, [256,128], is_training, bn_decay, scope='fa_layer3')
        l0_points = pointnet_fp_module(new_xyz, l1_xyz, None, l1_points, [128,128,128], is_training, bn_decay, scope='fa_layer4')

        # FC layers
        net = tf_util.conv1d(l0_points, 4, 1,
                                    padding='VALID', stride=1, scope='conv_shift_pred', activation_fn=None)
        if return_fullfea:
            shift_pred_seed_4d, shift_pred_full_4d = tf.split(net, [npoint_seed, num_point], axis=1)
            end_points['shift_pred_full_4d'] = shift_pred_full_4d
        else:
            shift_pred_seed_4d = net

        end_points['pc_seed'] = pc_seed
        end_points['shift_pred_seed_4d'] = shift_pred_seed_4d
        end_points['ind_seed'] = ind_seed

        return end_points

def sem_net(xyz, points, npoint_sem, num_category, ind_seed, end_points, scope, is_training, bn_decay=None, return_fullfea=False, mode='training'):
    ''' Per-point semantic prediction network.
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint_sem: int32 -- #points to sample for fast training
            num_category: int32 -- #output categories
            ind_seed: (batch_size, npoint_seed) sampling index of seed points
        Return:
            end_points: dict updated with
                sem_fea_seed: (batch_size, npoint_seed, nfea)
                sem_fea: (batch_size, npoint_sem, nfea)
                sem_fea_full: (batch_size, ndataset, nfea), only if return_fullfea
                sem_class_logits: (batch_size, npoint_sem, num_category) in training mode,
                    (batch_size, ndataset, num_category) in inference mode
                ind_sem: (batch_size, npoint_sem) sampling index of sem points
    '''
    with tf.variable_scope(scope) as sc:
        batch_size = xyz.get_shape()[0].value
        num_point = xyz.get_shape()[1].value
        npoint_seed = ind_seed.get_shape()[1].value

        ind_sem = farthest_point_sample(npoint_sem, xyz) # (batch_size, npoint_sem)
        new_xyz_sem = gather_point(xyz, ind_sem) # (batch_size, npoint_sem, 3)
        new_points_sem = gather_point(points, ind_sem)
        end_points['ind_sem'] = ind_sem

        new_xyz_seed = gather_point(xyz, ind_seed) # (batch_size, npoint_seed, 3)
        new_points_seed = gather_point(points, ind_seed)

        if return_fullfea:
            new_xyz = tf.concat((new_xyz_seed, new_xyz_sem, xyz), 1)
            new_points = tf.concat((new_points_seed, new_points_sem, points), 1)
        else:
            new_xyz = tf.concat((new_xyz_seed, new_xyz_sem), 1)
            new_points = tf.concat((new_points_seed, new_points_sem), 1)
        
        l0_xyz = xyz
        l0_points = points

        # Layer 1
        l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=2048, radius=0.2, nsample=32, mlp=[32,32,64], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1')
        l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=512, radius=0.4, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
        l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=128, radius=0.8, nsample=32, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer3')
        l4_xyz, l4_points, l4_indices = pointnet_sa_module(l3_xyz, l3_points, npoint=32, radius=1.6, nsample=32, mlp=[256,256,512], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer4')

        # with FPN
        if return_fullfea:
            end_points['sem_fea_full_l4'] = pointnet_fp_module(xyz, l4_xyz, points, l4_points, [], is_training, bn_decay, scope='fa_layer1_fpn')
            end_points['sem_fea_full_l3'] = pointnet_fp_module(xyz, l3_xyz, points, l3_points, [], is_training, bn_decay, scope='fa_layer2_fpn')
            end_points['sem_fea_full_l2'] = pointnet_fp_module(xyz, l2_xyz, points, l2_points, [], is_training, bn_decay, scope='fa_layer3_fpn')
            end_points['sem_fea_full_l1'] = pointnet_fp_module(xyz, l1_xyz, points, l1_points, [], is_training, bn_decay, scope='fa_layer4_fpn')

        # Feature Propagation layers
        l3_points = pointnet_fp_module(l3_xyz, l4_xyz, l3_points, l4_points, [256,256], is_training, bn_decay, scope='fa_layer1')
        l2_points = pointnet_fp_module(l2_xyz, l3_xyz, l2_points, l3_points, [256,256], is_training, bn_decay, scope='fa_layer2')
        l1_points = pointnet_fp_module(l1_xyz, l2_xyz, l1_points, l2_points, [256,128], is_training, bn_decay, scope='fa_layer3')
        l0_points = pointnet_fp_module(new_xyz, l1_xyz, new_points, l1_points, [128,128,128], is_training, bn_decay, scope='fa_layer4')

        # FC layers
        net = tf_util.conv1d(l0_points, 128, 1, padding='VALID', bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
        if return_fullfea:
            sem_fea_seed, sem_fea, sem_fea_full = tf.split(net, [npoint_seed, npoint_sem, num_point], axis=1)
            end_points['sem_fea_seed'] = sem_fea_seed
            end_points['sem_fea'] = sem_fea
            end_points['sem_fea_full'] = sem_fea_full
        else:
            sem_fea_seed, sem_fea = tf.split(net, [npoint_seed, npoint_sem], axis=1)
            end_points['sem_fea_seed'] = sem_fea_seed
            end_points['sem_fea'] = sem_fea

        if mode=='training':
            net = end_points['sem_fea']
        elif mode=='inference':
            net = end_points['sem_fea_full']
        net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
        sem_class_logits = tf_util.conv1d(net, num_category, 1, padding='VALID', activation_fn=None, scope='fc2')
        end_points['sem_class_logits'] = sem_class_logits

        return end_points


def pn2_fea_extractor(xyz, points, scope, is_training, bn_decay=None):
    ''' PointNet++ feature extractor.
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
        Return:
            new_points: (batch_size, ndataset, channel_out) TF tensor
    '''
    with tf.variable_scope(scope) as sc:
        batch_size = xyz.get_shape()[0].value
        num_point = xyz.get_shape()[1].value
        l0_xyz = xyz
        l0_points = points

        # Layer 1
        l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=2048, radius=0.2, nsample=32, mlp=[32,32,64], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1')
        l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=512, radius=0.4, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
        l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=128, radius=0.8, nsample=32, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer3')

        # Feature Propagation layers
        l2_points = pointnet_fp_module(l2_xyz, l3_xyz, l2_points, l3_points, [256,128], is_training, bn_decay, scope='fa_layer1')
        l1_points = pointnet_fp_module(l1_xyz, l2_xyz, l1_points, l2_points, [128,64], is_training, bn_decay, scope='fa_layer2')
        new_points = pointnet_fp_module(l0_xyz, l1_xyz, l0_points, l1_points, [64,64,64], is_training, bn_decay, scope='fa_layer3')

        return new_points


def single_encoding_net(pc, mlp_list, mlp_list2, scope, is_training, bn_decay):
    ''' The encoding network for a single instance point cloud
    Input:
        pc: [B, N, 3]
    Return:
        net: [B, nfea]
    '''
    with tf.variable_scope(scope) as myscope:
        net = tf.expand_dims(pc, 2)
        for i,num_out_channel in enumerate(mlp_list):
            net = tf_util.conv2d(net, num_out_channel, [1,1],
                                 padding='VALID', stride=[1,1],
                                 bn=True, is_training=is_training,
                                 scope='conv%d'%i, bn_decay=bn_decay)
        net = tf.reduce_max(net, axis=[1])
        net = tf.squeeze(net, 1)
        for i,num_out_channel in enumerate(mlp_list2):
            net = tf_util.fully_connected(net, num_out_channel, bn=True, is_training=is_training,
                                          scope='fc%d'%i, bn_decay=bn_decay)
        return net

def fea_trans_net(input_fea, mlp_list, scope, is_training, bn_decay):
    with tf.variable_scope(scope) as myscope:
        net = input_fea
        nlayer = len(mlp_list)
        for i,num_out_channel in enumerate(mlp_list):
            if i<nlayer-1:
                net = tf_util.conv1d(net, num_out_channel, 1, padding='VALID', bn=True, is_training=is_training,
                                     scope='conv%d'%i, bn_decay=bn_decay)
            else:
                net = tf_util.conv1d(net, num_out_channel, 1, padding='VALID', activation_fn=None, scope='conv%d'%i)
        return net

def sample(mean, log_var):
    # Reparameterization trick: z = mean + sigma * eps with sigma = exp(log_var / 2)
    # and eps ~ N(0, I), keeping the sample differentiable w.r.t. mean and log_var.
    z = mean + tf.exp(log_var/2.0) * tf.random_normal(tf.shape(mean), 0, 1, dtype=tf.float32)
    return z
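
# Sanity-check sketch (illustrative only): with log_var == 0 the draw reduces to
# mean plus unit-variance Gaussian noise -- the standard VAE reparameterization.
#
#   mean = tf.zeros((2, 128, 256))
#   log_var = tf.zeros((2, 128, 256))
#   z = sample(mean, log_var)  # z ~ N(0, I), shape [2, 128, 256]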

def decoding_net(feat, num_point, scope, is_training, bn_decay):
    ''' The decoding network for shape generation
    Input:
        feat: [B, nsmp, nfea]
    Return:
        pc: [B, nsmp, num_point, 3]
    '''
    with tf.variable_scope(scope) as myscope:
        nsmp = feat.get_shape()[1].value
        nfea = feat.get_shape()[2].value
        feat = tf.reshape(feat, [-1, nfea])
        # UPCONV Decoder
        if num_point<=3072 and num_point>1536:
            conv_feat = tf.expand_dims(tf.expand_dims(feat, 1),1)
            net = tf_util.conv2d_transpose(conv_feat, 512, kernel_size=[2,2], stride=[1,1], padding='VALID', scope='upconv1', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 256, kernel_size=[3,3], stride=[1,1], padding='VALID', scope='upconv2', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 256, kernel_size=[4,4], stride=[2,2], padding='VALID', scope='upconv3', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 128, kernel_size=[5,5], stride=[3,3], padding='VALID', scope='upconv4', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 3, kernel_size=[1,1], stride=[1,1], padding='VALID', scope='upconv5', activation_fn=None)
            num_point_conv = 1024
        elif num_point<=1536 and num_point>896:
            conv_feat = tf.expand_dims(tf.expand_dims(feat, 1),1)
            net = tf_util.conv2d_transpose(conv_feat, 512, kernel_size=[2,2], stride=[1,1], padding='VALID', scope='upconv1', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 256, kernel_size=[2,2], stride=[1,1], padding='VALID', scope='upconv2', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 256, kernel_size=[3,3], stride=[2,2], padding='VALID', scope='upconv3', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 128, kernel_size=[4,4], stride=[3,3], padding='VALID', scope='upconv4', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 3, kernel_size=[1,1], stride=[1,1], padding='VALID', scope='upconv5', activation_fn=None)
            num_point_conv = 484
        elif num_point<=896 and num_point>384:
            conv_feat = tf.expand_dims(tf.expand_dims(feat, 1),1)
            net = tf_util.conv2d_transpose(conv_feat, 512, kernel_size=[3,3], stride=[1,1], padding='VALID', scope='upconv1', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 256, kernel_size=[3,3], stride=[2,2], padding='VALID', scope='upconv2', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 128, kernel_size=[4,4], stride=[2,2], padding='VALID', scope='upconv3', bn=True, bn_decay=bn_decay, is_training=is_training)
            net = tf_util.conv2d_transpose(net, 3, kernel_size=[1,1], stride=[1,1], padding='VALID', scope='upconv4', activation_fn=None)
            num_point_conv = 256
        else:
            raise ValueError('decoding_net: num_point=%d is not in a supported range' % num_point)
        pc_upconv = tf.reshape(net, [-1, num_point_conv, 3])

        num_point_fc = num_point - num_point_conv
        # FC Decoder
        net = tf_util.fully_connected(feat, 512, bn=True, is_training=is_training, scope='de_fc2', bn_decay=bn_decay)
        net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='de_fc3', bn_decay=bn_decay)
        net = tf_util.fully_connected(net, num_point_fc*3, activation_fn=None, scope='de_fc4')
        pc_fc = tf.reshape(net, [-1, num_point_fc, 3])
        # Merge
        pc = tf.concat([pc_upconv, pc_fc], axis=1)
        pc = tf.reshape(pc, [-1, nsmp, num_point, 3])
        return pc
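
# The num_point_conv constants in decoding_net follow from the transposed-conv
# output-size rule out = (in - 1) * stride + kernel (VALID padding), starting
# from a 1x1 feature map:
#   branch 1: 1 -> 2 -> 4 -> (4-1)*2+4 = 10 -> (10-1)*3+5 = 32, so 32*32 = 1024 points
#   branch 2: 1 -> 2 -> 3 -> (3-1)*2+3 = 7  -> (7-1)*3+4  = 22, so 22*22 = 484 points
#   branch 3: 1 -> 3 -> (3-1)*2+3 = 7 -> (7-1)*2+4 = 16,        so 16*16 = 256 points
# The FC branch supplies the remaining num_point - num_point_conv points.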

def shape_proposal_net(pc, color, pc_ins, group_label, group_indicator, num_category, scope, is_training, bn_decay=None, nsmp=128, return_fullfea=False, mode='training'):
    ''' Shape proposal generation
    Inputs:
        pc: [B, NUM_POINT, 3]
        color: [B, NUM_POINT, 3]
        pc_ins: [B, NUM_GROUP, NUM_POINT_INS, 3], in world coord sys
        group_label: [B, NUM_POINT]
        group_indicator: [B, NUM_GROUP]
    Returns:
        end_points: dict holding, among others:
            fb_logits: [B, NUM_SAMPLE, 2] confidence logits (before softmax)
            fb_prob: [B, NUM_SAMPLE, 2] confidence probabilities
            bbox_ins_pred: [B, NUM_SAMPLE, (x, y, z, l, w, h)]
            entity_fea: [B, NUM_POINT, nfea] entity feature for each point (only if return_fullfea)
            center_pos: [B, NUM_POINT, 3] center coordinate for each point, in world coord sys (only if return_fullfea)
    '''
    with tf.variable_scope(scope) as myscope:
        # Parameter extraction
        batch_size = pc.get_shape()[0].value
        ngroup = pc_ins.get_shape()[1].value
        nsmp_ins = pc_ins.get_shape()[2].value
        end_points = {}

        # Shift prediction, ind_seed [B, nsmp], shift_pred_seed [B, nsmp, 3]
        end_points = shift_pred_net(pc, color, nsmp, end_points, 'shift_predictor', is_training, bn_decay=bn_decay, return_fullfea=return_fullfea)
        pc_seed = end_points['pc_seed']
        shift_pred_seed_4d = end_points['shift_pred_seed_4d']
        ind_seed = end_points['ind_seed']
        shift_pred_seed = tf.multiply(shift_pred_seed_4d[:,:,:3], shift_pred_seed_4d[:,:,3:])

        # Semantic prediction, sem_fea_seed [B, nsmp, nfea]
        end_points = sem_net(pc, color, 1024, num_category, ind_seed, end_points, 'sem_predictor', is_training, bn_decay=bn_decay, return_fullfea=return_fullfea, mode=mode)
        sem_fea_seed = end_points['sem_fea_seed']

        # Encode instance, pcfea_ins_centered [B, ngroup, nfea_ins], pc_ins_center [B, ngroup, 1, 3]
        pc_ins_center = (tf.reduce_max(pc_ins, 2, keep_dims=True)+tf.reduce_min(pc_ins, 2, keep_dims=True))/2 # [B, ngroup, 1, 3] -> requires random padding for pc_ins generation
        pc_ins_centered = pc_ins-pc_ins_center
        idx = tf.where(tf.greater(group_indicator, 0))
        pc_ins_centered_list = tf.gather_nd(pc_ins_centered, idx)
        pcfea_ins_centered_list = single_encoding_net(pc_ins_centered_list, [64, 256, 512], [256], 'instance_encoder', is_training, bn_decay)
        nfea_ins = pcfea_ins_centered_list.get_shape()[1].value
        pcfea_ins_centered = tf.scatter_nd(tf.cast(idx,tf.int32), pcfea_ins_centered_list, tf.constant([batch_size, ngroup, nfea_ins])) # [B, ngroup, nfea_ins]
        
        # Collect instance feature for seed points [B, nsmp, nfea_seed]
        idx = tf.where(tf.greater_equal(ind_seed,0))
        ind_seed_aug = tf.concat((tf.expand_dims(tf.cast(idx[:,0],tf.int32),-1),tf.reshape(ind_seed,[-1,1])),1)
        group_label_seed = tf.reshape(tf.gather_nd(group_label, ind_seed_aug), [-1, nsmp]) # [B, nsmp]
        idx = tf.where(tf.greater_equal(group_label_seed,0))
        group_label_seed_aug = tf.concat((tf.expand_dims(tf.cast(idx[:,0],tf.int32),-1),tf.reshape(group_label_seed,[-1,1])),1)
        pcfea_ins_seed = tf.reshape(tf.gather_nd(pcfea_ins_centered, group_label_seed_aug), [-1, nsmp, nfea_ins])
        pc_ins_centered_seed = tf.reshape(tf.gather_nd(pc_ins_centered, group_label_seed_aug), [-1, nsmp, nsmp_ins, 3])
        pc_ins_center_seed = tf.reshape(tf.gather_nd(pc_ins_center, group_label_seed_aug), [-1, nsmp, 1, 3])
       
        # Encode context, pcfea_seed [B, nsmp, nfea_context]
        _, pcfea_seed, _, _ = multi_encoding_net(pc, color, nsmp, [0.5,1.0,1.5], [256,256,512], [[64,128,256], [64,128,256], [64,128,256]], [], is_training, bn_decay, scope='context_encoder', use_xyz=True, output_shift=False, shift_pred=tf.stop_gradient(shift_pred_seed), fps_idx=ind_seed)

        # Compute foreground/background score [B, nsmp, 2]
        fb_logits = fea_trans_net(pcfea_seed, [256, 64, 2], 'fb_logits', is_training, bn_decay)
        fb_prob = tf.nn.softmax(fb_logits, -1)

        # Compute mu and sigma [B, nsmp, 512]
        mu_sigma_c = fea_trans_net(tf.concat((sem_fea_seed, pcfea_seed), axis=-1), [256, 512, 512], 'mu_sigma_c', is_training, bn_decay)
        mu_sigma_x = fea_trans_net(tf.concat((sem_fea_seed, pcfea_seed, pcfea_ins_seed), axis=-1), [256, 512, 512], 'mu_sigma_x', is_training, bn_decay)
        
        # Sample z [B, nsmp, 256]
        mean = mu_sigma_x[:,:,:256]
        log_var = mu_sigma_x[:,:,256:]
        log_var = tf.clip_by_value(log_var, -10.0, 1.0)
        cmean = mu_sigma_c[:,:,:256]
        clog_var = mu_sigma_c[:,:,256:]
        clog_var = tf.clip_by_value(clog_var, -10.0, 1.0)
        zi = sample(mean, log_var)
        zc = cmean
        z = tf.cond(is_training, lambda: zi, lambda: zc)
        
        # Decode shapes pc [B, nsmp, nsmp_ins, 3]
        gcfeat = tf_util.conv1d(pcfea_seed, 256, 1, padding='VALID', bn=True, is_training=is_training,
                                scope='dec_fc', bn_decay=bn_decay)
        feat = tf.concat((z, gcfeat), axis=-1)
        pc_ins_pred = decoding_net(feat, nsmp_ins, 'decoder', is_training=is_training, bn_decay=bn_decay)
        pc_ins_pred = pc_ins_pred + tf.stop_gradient(tf.expand_dims(shift_pred_seed, 2))

        # Collect bbox for reconstructions
        pc_ins_pred_world_coord = pc_ins_pred + tf.expand_dims(pc_seed, 2)
        bbox_ins_pred = tf.concat(((tf.reduce_max(pc_ins_pred_world_coord, 2)+tf.reduce_min(pc_ins_pred_world_coord, 2))/2, 
                                  tf.reduce_max(pc_ins_pred_world_coord, 2)-tf.reduce_min(pc_ins_pred_world_coord, 2)), 2) # [B, nsmp, 6] -> center + l,w,h

        # Propagate seed feature and center position
        if return_fullfea:
            entity_fea = pointnet_fp_module(pc, pc_seed, None, pcfea_seed, [], is_training=False, bn_decay=None, scope='entity_fea_prop', bn=False)
            shift_pred = tf.multiply(end_points['shift_pred_full_4d'][:,:,:3], end_points['shift_pred_full_4d'][:,:,3:])
            center_pos = pc+shift_pred
            end_points['entity_fea'] = entity_fea # [B, N, 256] entity feature of each point
            end_points['center_pos'] = center_pos # [B, N, 3] center location in the world coord sys

        # Store end_points
        end_points['shift_pred_seed'] = shift_pred_seed # [B, nsmp, 3], offset from seed point to ins center
        end_points['shift_pred_seed_4d'] = shift_pred_seed_4d # [B, nsmp, 4], offset from seed point to ins center
        end_points['pc_seed'] = pc_seed # [B, nsmp, 3], seed point coordinate in world coord sys
        end_points['ind_seed'] = ind_seed # [B, nsmp], seed index
        end_points['pc_ins_centered_seed'] = pc_ins_centered_seed # [B, nsmp, nsmp_ins, 3], centered gt instance point cloud for each seed
        end_points['pc_ins_center_seed'] = pc_ins_center_seed # [B, nsmp, 1, 3], gt instance center for each seed in world coord sys
        end_points['mean'] = mean # [B, nsmp, 256]
        end_points['log_var'] = log_var
        end_points['cmean'] = cmean
        end_points['clog_var'] = clog_var
        end_points['fb_logits'] = fb_logits # [B, nsmp, 2] foreground/background logits
        end_points['fb_prob'] = fb_prob # [B, nsmp, 2] foreground/background probability
        end_points['pc_ins_pred'] = pc_ins_pred # [B, nsmp, nsmp_ins, 3], in local sys, needs to add pc_seed
        end_points['bbox_ins_pred'] = bbox_ins_pred # [B, nsmp, 6]

        return end_points

def nms_3d(boxes, scores, pre_nms_limit, max_output_size, iou_threshold=0.5, score_threshold=float('-inf')):
    ''' Non-maximum suppression in 3D (NumPy; called through tf.py_func in refine_detections)
    Inputs:
        boxes: [B, N, 6] center + l,w,h
        scores: [B, N] prob between 0 and 1
    Outputs:
        selected_indices: [B, M], padded with -1 when fewer than max_output_size boxes survive
    '''

    batch_size = scores.shape[0]    
    num_box_input = scores.shape[1]    
    sidx = np.argsort(-scores, 1) # [B, N] from large to small
    selected_indices = -np.ones((batch_size, max_output_size), dtype=np.int32) # [B, M]
    for i in range(batch_size):
        cursidx = sidx[i,:]
        curscores = scores[i,:]
        curvolume = boxes[i,:,3]*boxes[i,:,4]*boxes[i,:,5]
        if pre_nms_limit>0:
            cursidx = cursidx[:pre_nms_limit]
        cursidx = cursidx[curscores[cursidx]>score_threshold]
        count = 0
        while len(cursidx)>0 and count<max_output_size:
            selected_indices[i,count] = cursidx[0]
            count += 1
            vA = np.maximum(boxes[i,[cursidx[0]],:3]-boxes[i,[cursidx[0]],3:]/2, boxes[i,cursidx,:3]-boxes[i,cursidx,3:]/2)
            vB = np.minimum(boxes[i,[cursidx[0]],:3]+boxes[i,[cursidx[0]],3:]/2, boxes[i,cursidx,:3]+boxes[i,cursidx,3:]/2)
            intersection_cube = np.maximum(vB-vA,0)
            intersection_volume = intersection_cube[:,0]*intersection_cube[:,1]*intersection_cube[:,2]
            iou = np.divide(intersection_volume,curvolume[cursidx]+curvolume[cursidx[0]]-intersection_volume+1e-8)
            cursidx = np.delete(cursidx, np.where(iou>iou_threshold)[0])
    return selected_indices
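
# A toy run of nms_3d (illustrative only): two heavily overlapping unit cubes and
# one disjoint cube; the lower-scoring duplicate is suppressed at iou_threshold=0.5.
#
#   boxes = np.array([[[0.0, 0, 0, 1, 1, 1],     # kept (score 0.9)
#                      [0.05, 0, 0, 1, 1, 1],    # IoU ~0.9 with the first -> suppressed
#                      [5.0, 5, 5, 1, 1, 1]]])   # disjoint -> kept
#   scores = np.array([[0.9, 0.8, 0.7]])
#   keep = nms_3d(boxes, scores, pre_nms_limit=-1, max_output_size=3)
#   # keep -> [[0, 2, -1]] (indices into the N axis, -1 padded)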

def gather_selection(source, selected_idx, max_selected_size):
    '''
    Inputs:
        source: [B, N, C]
        selected_idx: [B, M], -1 means not selecting anything
    Returns:
        target: [B, M, C], 0 padded
    '''
    batch_size = source.get_shape()[0].value
    fea_size = source.get_shape()[2].value
    pos_idx = tf.cast(tf.where(tf.greater_equal(selected_idx,0)), tf.int32)
    selected_idx_vec = tf.gather_nd(selected_idx, pos_idx)
    target_vec = tf.gather_nd(source, tf.concat((tf.expand_dims(pos_idx[:,0],-1), tf.reshape(selected_idx_vec,[-1,1])),1))
    target = tf.scatter_nd(pos_idx, target_vec, tf.constant([batch_size, max_selected_size, fea_size]))
    return target

def trim_zeros_graph(boxes, name=None):
    """Often boxes are represented with matrices of shape [N, 4] and
    are padded with zeros. This removes zero boxes.

    boxes: [N, 6] matrix of boxes.
    non_zeros: [N] a 1D boolean mask identifying the rows to keep
    """
    non_zeros = tf.cast(tf.reduce_sum(tf.abs(boxes), axis=1), tf.bool)
    boxes = tf.boolean_mask(boxes, non_zeros, name=name)
    return boxes, non_zeros

def batch_slice(inputs, graph_fn, batch_size, names=None):
    """Splits inputs into slices and feeds each slice to a copy of the given
    computation graph and then combines the results. It allows you to run a
    graph on a batch of inputs even if the graph is written to support one
    instance only.
    inputs: list of tensors. All must have the same first dimension length
    graph_fn: A function that returns a TF tensor that's part of a graph.
    batch_size: number of slices to divide the data into.
    names: If provided, assigns names to the resulting tensors.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    outputs = []
    for i in range(batch_size):
        inputs_slice = [x[i] for x in inputs]
        output_slice = graph_fn(*inputs_slice)
        if not isinstance(output_slice, (tuple, list)):
            output_slice = [output_slice]
        outputs.append(output_slice)
    # Change outputs from a list of slices where each is
    # a list of outputs to a list of outputs and each has
    # a list of slices
    outputs = list(zip(*outputs))

    if names is None:
        names = [None] * len(outputs)

    result = [tf.stack(o, axis=0, name=n)
              for o, n in zip(outputs, names)]
    if len(result) == 1:
        result = result[0]

    return result
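
# Usage sketch (illustrative only): batch_slice runs a single-sample graph
# function over every sample in a batch, which is how per-scene helpers such
# as detection_target_gen can be applied batch-wise.
#
#   per_point_sqnorm = batch_slice(
#       [pc_pl], lambda pc: tf.reduce_sum(tf.square(pc), axis=1),
#       batch_size=2)  # [B, NUM_POINT]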

def box_shrink(box, pc):
    ''' Shrink bounding box so that it is tight with respect to pc
    Inputs:
        box: [B, NUM_SAMPLE, 6]
        pc: [B, NUM_POINT, 3]
    Returns:
        box: [B, NUM_SAMPLE, 6]
    '''
    pc_aug = tf.expand_dims(pc, 1) # [B, 1, NUM_POINT, 3]
    box_aug = tf.expand_dims(box, 2) # [B, NUM_SAMPLE, 1, 6]
    box_masks = tf.logical_and(pc_aug>=(box_aug[:,:,:,:3]-box_aug[:,:,:,3:]/2),
                               pc_aug<=(box_aug[:,:,:,:3]+box_aug[:,:,:,3:]/2)) # [B, NUM_SAMPLE, NUM_POINT, 3]
    box_masks = tf.logical_and(tf.logical_and(box_masks[:,:,:,0], box_masks[:,:,:,1]), box_masks[:,:,:,2]) # [B, NUM_SAMPLE, NUM_POINT]
    box_out_masks = 1-tf.cast(tf.expand_dims(box_masks, -1), tf.float32) # [B, NUM_SAMPLE, NUM_POINT, 1]
    gamma = 1e4 # a large number for the box estimation trick
    box_max = tf.reduce_max(pc_aug-gamma*box_out_masks,2) # [B, NUM_SAMPLE, 3]
    box_min = tf.reduce_min(pc_aug+gamma*box_out_masks,2)
    box = tf.concat( ((box_max+box_min)/2, box_max-box_min+1e-3), 2) # [B, NUM_SAMPLE, 6]
    keep = tf.greater(box_max-box_min, 0)
    keep = tf.logical_and(tf.logical_and(keep[:,:,0], keep[:,:,1]), keep[:,:,2])
    keep = tf.expand_dims(tf.cast(keep, tf.float32), -1)
    box = tf.multiply(box, keep)
    return box

def box_refinement(box, gt_box):
    """Compute refinement needed to transform box to gt_box.
    box and gt_box are [N, (center_x, center_y, center_z, l, w, h)]
    """
    box = tf.cast(box, tf.float32)
    gt_box = tf.cast(gt_box, tf.float32)

    dz = (gt_box[:,2] - box[:,2]) / (box[:,5]+1e-8)
    dy = (gt_box[:,1] - box[:,1]) / (box[:,4]+1e-8)
    dx = (gt_box[:,0] - box[:,0]) / (box[:,3]+1e-8)
    dh = tf.log(gt_box[:,5] / (box[:,5]+1e-8))
    dw = tf.log(gt_box[:,4] / (box[:,4]+1e-8))
    dl = tf.log(gt_box[:,3] / (box[:,3]+1e-8))

    result = tf.stack([dz, dy, dx, dh, dw, dl], axis=1)
    return result

def apply_box_delta(box, delta):
    ''' Apply bounding box delta to refine box
    Inputs:
        box: [NUM_SAMPLE, 6]
        delta: [NUM_SAMPLE, 6]
    Returns:
        box_refined: [NUM_SAMPLE, 6]
    '''
    delta = tf.stack([delta[:,2], delta[:,1], delta[:,0], delta[:,5], delta[:,4], delta[:,3]], axis=1)
    box_refined_part2 = tf.multiply(tf.exp(delta[:,3:]), box[:,3:])
    box_refined_part1 = tf.multiply(delta[:,:3], box[:,3:])+box[:,:3]
    box_refined = tf.concat((box_refined_part1, box_refined_part2), axis=1)
    return box_refined
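
# Consistency note (illustrative only): apply_box_delta inverts box_refinement,
# so ignoring the BBOX_STD_DEV scaling used in the training pipeline:
#
#   deltas = box_refinement(box, gt_box)      # [N, (dz, dy, dx, dh, dw, dl)]
#   recovered = apply_box_delta(box, deltas)  # ~= gt_box, up to the 1e-8 epsilons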

def sample_points_within_box(masks, nsmp):
    '''
    Inputs:
        masks: [nmask, npoint]
        nsmp: scalar
    Returns:
        masks_selection_idx: [nmask, nsmp]
    '''
    if masks.shape[0]==0:
        return np.zeros((0, nsmp), dtype=np.int32)
    else:
        masks_selection_idx = []
        for i in range(masks.shape[0]):
            in_box = np.where(masks[i, :])[0]  # indices of points falling inside this mask
            if len(in_box) > 0:
                masks_selection_idx.append(np.random.choice(in_box, nsmp, replace=True))
            else:
                masks_selection_idx.append(np.zeros(nsmp, dtype=np.int32))
        masks_selection_idx = np.stack(masks_selection_idx, 0).astype(np.int32)
        return masks_selection_idx
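
# Toy example (illustrative only): draw 4 point indices from inside each mask;
# an empty mask falls back to all-zero indices.
#
#   masks = np.array([[True, False, True, False],
#                     [False, False, False, False]])
#   idx = sample_points_within_box(masks, 4)
#   # idx.shape == (2, 4); row 0 draws from {0, 2} with replacement, row 1 is all zeros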

def spn_target_gen(proposals, proposal_seed_class_ids, gt_class_ids, gt_boxes):
    ''' SPN target generation
    Inputs:
        proposals: [NUM_SAMPLE, 6]
        proposal_seed_class_ids: [NUM_SAMPLE] - 0 is background class and 1 is foreground class
        gt_class_ids: [NUM_GROUP] - 0 is background class and 1 is foreground class
        gt_boxes: [NUM_GROUP, 6]
    Returns:
        spn_match: [NUM_SAMPLE], 1 = positive, -1 = negative, 0 = neutral
    '''
    # Remove zero padding
    non_zeros = tf.where(tf.greater(gt_class_ids, 0))[:,0]
    gt_boxes = tf.gather(gt_boxes, non_zeros, axis=0, name="spn_trim_gt_boxes")
    gt_class_ids = tf.gather(gt_class_ids, non_zeros, axis=0, name="spn_trim_gt_class_ids")

    # Compute IoU [NUM_SAMPLE, NUM_GROUP_TRIMMED]
    proposals_aug = tf.expand_dims(proposals, 1)
    gt_boxes_aug = tf.expand_dims(gt_boxes, 0)
    proposal_volume = tf.multiply(tf.multiply(proposals_aug[:,:,3], proposals_aug[:,:,4]), proposals_aug[:,:,5])
    gt_boxes_volume = tf.multiply(tf.multiply(gt_boxes_aug[:,:,3], gt_boxes_aug[:,:,4]), gt_boxes_aug[:,:,5])
    vA = tf.maximum(proposals_aug[:,:,:3]-proposals_aug[:,:,3:]/2, gt_boxes_aug[:,:,:3]-gt_boxes_aug[:,:,3:]/2)
    vB = tf.minimum(proposals_aug[:,:,:3]+proposals_aug[:,:,3:]/2, gt_boxes_aug[:,:,:3]+gt_boxes_aug[:,:,3:]/2)
    intersection_cube = tf.maximum(vB-vA,0)
    intersection_volume = tf.multiply(tf.multiply(intersection_cube[:,:,0], intersection_cube[:,:,1]), intersection_cube[:,:,2])
    ious = tf.divide(intersection_volume,proposal_volume+gt_boxes_volume-intersection_volume+1e-8)

    # Determine positive and negative ROIs
    roi_iou_max = tf.reduce_max(ious, axis=1)
    # 1. Positive ROIs are those with >=0.5 IoU with a GT box and gt class id is 1
    # We also guarantee that the boxes with the largest IoU with a GT box is also treated as positive
    spn_match_positive = tf.logical_and(tf.greater_equal(roi_iou_max, 0.5), tf.equal(proposal_seed_class_ids,1))
    masked_ious = tf.multiply(ious, tf.cast(tf.equal(tf.expand_dims(proposal_seed_class_ids,-1),1), tf.float32))
    positive_aug_idx = tf.argmax(masked_ious,0)
    positive_aug_idx = tf.boolean_mask(positive_aug_idx, tf.reduce_max(masked_ious, 0)>0)
    updates = tf.cast(tf.ones_like(positive_aug_idx), tf.float32)
    # scatter_nd expects indices of shape [num_updates, index_depth], hence the expand_dims
    spn_match_positive_aug = tf.cond(
        tf.size(positive_aug_idx)>0,
        true_fn = lambda: tf.scatter_nd(tf.expand_dims(positive_aug_idx, -1), updates, spn_match_positive.shape),
        false_fn = lambda: tf.zeros(spn_match_positive.shape, dtype=tf.float32))
    spn_match_positive = tf.greater(tf.cast(spn_match_positive, tf.float32)+spn_match_positive_aug, 0)
    # 2. Negative ROIs are those with < 0.5 IoU with every GT box and not a positive ROI
    spn_match_negative = tf.logical_and(tf.less(roi_iou_max, 0.5), tf.logical_not(spn_match_positive))

    spn_match = tf.cast(spn_match_positive, tf.float32)-tf.cast(spn_match_negative, tf.float32)

    return spn_match


def detection_target_gen(proposals, gt_class_ids, gt_boxes, gt_masks, pc, config):
    ''' Generate detection targets for training
    Inputs:
        proposals: [SPN_NMS_MAX_SIZE, 6], zero padded
        gt_class_ids: [NUM_GROUP] - 0 is background class and objects start from 1
        gt_boxes: [NUM_GROUP, 6]
        gt_masks: [NUM_POINT, NUM_GROUP]
        pc: [NUM_POINT, 3]
    Returns:
        rois: [TRAIN_ROIS_PER_IMAGE, 6], zero padded
        roi_gt_class_ids: [TRAIN_ROIS_PER_IMAGE] - 0 is invalid class and used for padding, zero padded
        deltas: [TRAIN_ROIS_PER_IMAGE, 6], zero padded
        masks_selection_idx: [TRAIN_ROIS_PER_IMAGE, NUM_POINT_INS_MASK] - which points are selected from the initial point cloud, zero padded
        masks: [TRAIN_ROIS_PER_IMAGE, NUM_POINT_INS_MASK] - binary mask, zero padded
    '''
    # Remove zero padding
    proposals, _ = trim_zeros_graph(proposals, name='trim_proposals')
    gt_boxes, non_zeros = trim_zeros_graph(gt_boxes, name="trim_gt_boxes")
    gt_class_ids = tf.boolean_mask(gt_class_ids, non_zeros,
                                   name="trim_gt_class_ids")
    gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=1,
                         name="trim_gt_masks")

    # Remove empty proposals
    pc_aug = tf.expand_dims(pc,1) # [N, 1, 3]
    proposals_aug = tf.expand_dims(proposals, 0) # [1, NP, 6]
    roi_masks = tf.logical_and(pc_aug>=(proposals_aug[:,:,:3]-proposals_aug[:,:,3:]/2),
                               pc_aug<=(proposals_aug[:,:,:3]+proposals_aug[:,:,3:]/2))
    roi_masks = tf.logical_and(tf.logical_and(roi_masks[:,:,0], roi_masks[:,:,1]), roi_masks[:,:,2]) # [N, NP]
    non_empty_idx = tf.where(tf.greater(tf.reduce_sum(tf.cast(roi_masks, tf.float32), 0), 0))[:,0]
    roi_masks = tf.gather(roi_masks, non_empty_idx, axis=1) # [N, NP']
    proposals = tf.gather(proposals, non_empty_idx, axis=0) # [NP', 6]

    # Compute IoU [n_proposal, n_gt_boxes]
    proposals_aug = tf.expand_dims(proposals, 1)
    gt_boxes_aug = tf.expand_dims(gt_boxes, 0)
    proposal_volume = tf.multiply(tf.multiply(proposals_aug[:,:,3], proposals_aug[:,:,4]), proposals_aug[:,:,5])
    gt_boxes_volume = tf.multiply(tf.multiply(gt_boxes_aug[:,:,3], gt_boxes_aug[:,:,4]), gt_boxes_aug[:,:,5])
    vA = tf.maximum(proposals_aug[:,:,:3]-proposals_aug[:,:,3:]/2, gt_boxes_aug[:,:,:3]-gt_boxes_aug[:,:,3:]/2)
    vB = tf.minimum(proposals_aug[:,:,:3]+proposals_aug[:,:,3:]/2, gt_boxes_aug[:,:,:3]+gt_boxes_aug[:,:,3:]/2)
    intersection_cube = tf.maximum(vB-vA,0)
    intersection_volume = tf.multiply(tf.multiply(intersection_cube[:,:,0], intersection_cube[:,:,1]), intersection_cube[:,:,2])
    ious = tf.divide(intersection_volume,proposal_volume+gt_boxes_volume-intersection_volume+1e-8)
    
    # Determine positive and negative ROIs
    roi_iou_max = tf.reduce_max(ious, axis=1)
    # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
    positive_indices = tf.where(roi_iou_max >= 0.5)[:, 0]
    # 2. Negative ROIs are those with < 0.5 with every GT box
    negative_indices = tf.where(roi_iou_max < 0.5)[:, 0]

    # Subsample ROIs. Aim for 33% positive
    # Positive ROIs
    positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                         config.ROI_POSITIVE_RATIO)
    positive_indices = tf.random_shuffle(positive_indices)[:positive_count]
    positive_count = tf.shape(positive_indices)[0]
    # Negative ROIs. Add enough to maintain positive:negative ratio.
    r = 1.0 / config.ROI_POSITIVE_RATIO
    negative_count = tf.cast(r * tf.cast(positive_count, tf.float32), tf.int32) - positive_count
    negative_indices = tf.random_shuffle(negative_indices)[:negative_count]
    # Gather selected ROIs [POSITIVE_COUNT/NEGATIVE_COUNT, 6]
    positive_rois = tf.gather(proposals, positive_indices)
    negative_rois = tf.gather(proposals, negative_indices)

    # Assign positive ROIs to GT boxes. roi_gt_boxes: [POSITIVE_COUNT, 6], roi_gt_class_ids: [POSITIVE_COUNT]
    positive_ious = tf.gather(ious, positive_indices)
    roi_gt_box_assignment = tf.cond(
        tf.greater(tf.shape(positive_ious)[1], 0),
        true_fn = lambda: tf.argmax(positive_ious, axis=1),
        false_fn = lambda: tf.cast(tf.constant([]),tf.int64)
    )
    roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)
    roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)

    # Compute bbox refinement for positive ROIs. delta: [POSITIVE_COUNT, 6]
    deltas = box_refinement(positive_rois, roi_gt_boxes)
    deltas /= config.BBOX_STD_DEV

    # Compute mask targets
    roi_gt_masks = tf.cast(tf.gather(gt_masks, roi_gt_box_assignment, axis=1), tf.bool) # [NUM_POINT, POSITIVE_COUNT]
    positive_roi_masks = tf.gather(roi_masks, positive_indices, axis=1)
    masks_full = tf.transpose(tf.logical_and(positive_roi_masks, roi_gt_masks)) # [POSITIVE_COUNT, NUM_POINT]
    masks_selection_idx = tf.stop_gradient(tf.py_func(sample_points_within_box, 
        [tf.transpose(positive_roi_masks), config.NUM_POINT_INS_MASK], tf.int32)) # [POSITIVE_COUNT, NUM_POINT_INS_MASK]
    smp_idx = tf.reshape(tf.tile(tf.reshape(tf.range(positive_count),[-1,1]),[1, config.NUM_POINT_INS_MASK]),[-1,1])
    smp_idx = tf.concat((smp_idx, tf.reshape(masks_selection_idx,[-1,1])),1)
    masks = tf.reshape(tf.gather_nd(masks_full, smp_idx),[-1, config.NUM_POINT_INS_MASK]) # [POSITIVE_COUNT, NUM_POINT_INS_MASK]

    # Append negative ROIs and pad bbox deltas and masks that
    # are not used for negative ROIs with zeros.
    rois = tf.concat([positive_rois, negative_rois], axis=0)
    N = tf.shape(negative_rois)[0]
    P = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
    rois = tf.pad(rois, [(0, P), (0, 0)])
    roi_gt_class_ids = tf.pad(roi_gt_class_ids, [(0, N + P)])
    deltas = tf.pad(deltas, [(0, N + P), (0, 0)])
    masks_selection_idx = tf.pad(masks_selection_idx, [[0, N + P], (0, 0)])
    masks = tf.pad(masks, [[0, N + P], (0, 0)])

    return rois, roi_gt_class_ids, deltas, masks_selection_idx, masks

def mask_selection_gen(proposals, pc, num_rois, config, empty_removal=True):
    ''' Generate per-proposal point selection indices
    Inputs:
        proposals: [SPN_NMS_MAX_SIZE, 6], zero padded
        pc: [NUM_POINT, 3]
    Returns:
        proposals: [NUM_ROIS, 6]
        masks_selection_idx: [NUM_ROIS, NUM_POINT_INS_MASK] - which points are selected from the initial point cloud, zero padded
    '''
    # Remove zero padding
    proposals, _ = trim_zeros_graph(proposals, name='trim_proposals')

    # Remove empty proposals
    pc_aug = tf.expand_dims(pc,1) # [N, 1, 3]
    proposals_aug = tf.expand_dims(proposals, 0) # [1, NP, 6]
    roi_masks = tf.logical_and(pc_aug>=(proposals_aug[:,:,:3]-proposals_aug[:,:,3:]/2-1e-3),
                               pc_aug<=(proposals_aug[:,:,:3]+proposals_aug[:,:,3:]/2+1e-3))
    roi_masks = tf.logical_and(tf.logical_and(roi_masks[:,:,0], roi_masks[:,:,1]), roi_masks[:,:,2]) # [N, NP]
    if empty_removal:
        non_empty_idx = tf.where(tf.greater(tf.reduce_sum(tf.cast(roi_masks, tf.float32), 0), 0))[:,0]
        roi_masks = tf.gather(roi_masks, non_empty_idx, axis=1) # [N, NP']
        proposals = tf.gather(proposals, non_empty_idx, axis=0) # [NP', 6]
    proposals_count = tf.shape(proposals)[0]

    # Generate mask selection index
    masks_selection_idx = tf.stop_gradient(tf.py_func(sample_points_within_box, 
        [tf.transpose(roi_masks), config.NUM_POINT_INS_MASK], tf.int32)) # [NP', NUM_POINT_INS_MASK]

    # Pad proposals and selection indices with zeros up to num_rois.
    P = tf.maximum(num_rois - proposals_count, 0)
    masks_selection_idx = tf.pad(masks_selection_idx, [[0, P], (0, 0)])
    proposals = tf.pad(proposals, [(0, P), (0, 0)])

    return proposals, masks_selection_idx

def points_cropping(pc, pc_fea, pc_center, rois, masks_selection_idx, num_rois, num_point_per_roi, normalize_crop_region=True):
    ''' Crop points for network heads, in analogy to ROIAlign
    Inputs:
        pc: [B, NUM_POINT, 3]
        pc_fea: [B, NUM_POINT, NFEA]
        pc_center: [B, NUM_POINT, 3]
        rois: [B, NUM_ROIS, 6], zero padded
        masks_selection_idx: [B, NUM_ROIS, NUM_POINT_PER_ROI]
    Returns:
        pc_fea_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]
        pc_center_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3], in local (optionally normalized) box coordinates
        pc_coord_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3], in local (optionally normalized) box coordinates
        pc_coord_cropped_unnormalized: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3], in world coordinates
    '''
    batch_size = pc.get_shape()[0].value
    smp_idx = tf.reshape(tf.tile(tf.reshape(tf.range(batch_size),[-1,1]),[1, num_rois*num_point_per_roi]),[-1,1])
    smp_idx = tf.concat((smp_idx, tf.reshape(masks_selection_idx,[-1,1])),1)
    pc_fea_cropped = tf.reshape(tf.gather_nd(pc_fea, smp_idx),[batch_size, num_rois, num_point_per_roi, -1])
    pc_center_cropped = tf.reshape(tf.gather_nd(pc_center, smp_idx),[batch_size, num_rois, num_point_per_roi, -1])
    pc_coord_cropped_unnormalized = tf.reshape(tf.gather_nd(pc, smp_idx),[batch_size, num_rois, num_point_per_roi, -1])
    pc_coord_cropped = pc_coord_cropped_unnormalized

    # convert world coord to local
    rois_center = tf.expand_dims(rois[:,:,:3], 2)
    pc_coord_cropped = pc_coord_cropped-rois_center
    pc_center_cropped = pc_center_cropped-rois_center
    if normalize_crop_region:
        # scale box to [1,1,1]
        rois = rois+tf.cast(tf.equal(tf.reduce_sum(rois, 2, keep_dims=True),0),tf.float32)
        rois_size = tf.expand_dims(rois[:,:,3:], 2)
        pc_coord_cropped = tf.divide(pc_coord_cropped, rois_size)
        pc_center_cropped = tf.divide(pc_center_cropped, rois_size)
    return pc_fea_cropped, pc_center_cropped, pc_coord_cropped, pc_coord_cropped_unnormalized
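
# The smp_idx construction above is a batched gather: each (batch, roi, point)
# triple pairs its batch index with the point index from masks_selection_idx so
# a single tf.gather_nd can pull rows out of the [B, NUM_POINT, C] tensors. A
# rough NumPy equivalent for one feature tensor:
#
#   out = pc_fea_np[np.arange(B)[:, None, None],
#                   masks_selection_idx_np]  # [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]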

def refine_detections(rois, probs, deltas, pc, fb_prob, sem_prob, config):
    '''Refine classified proposals and filter overlaps and return final
    detections.
    Inputs:
        rois: [NUM_ROIS, 6], zero padded, in world coord sys
        probs: [NUM_ROIS, NUM_CATEGORY] - 0 is background class and objects start from 1
        deltas: [NUM_ROIS, NUM_CATEGORY, 6]
        pc: [NUM_POINT, 3]
        fb_prob: [NUM_ROIS]
        sem_prob: [NUM_ROIS]
    Returns:
        detections: [NUM_DETECTIONS, (center_x, center_y, center_z, l, w, h, class_id, score)]
    '''
    # Class IDs per ROI
    class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
    # Class probability of the top class of each ROI
    indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
    class_scores = tf.gather_nd(probs, indices) # [NUM_ROIS]
    # Class-specific bounding box deltas
    deltas_specific = tf.gather_nd(deltas, indices) # [NUM_ROIS, 6]
    # Apply bounding box deltas

    # Shape: [NUM_ROIS, (center_x, center_y, center_z, l, w, h)]
    refined_rois = apply_box_delta(rois, deltas_specific * config.BBOX_STD_DEV)
    # Shrink boxes
    if config.SHRINK_BOX:
        refined_rois = tf.squeeze(box_shrink(tf.expand_dims(refined_rois, 0), tf.expand_dims(pc,0)),0)

    # Filter out background boxes
    keep = tf.where(class_ids > 0)[:, 0]
    # Filter out low confidence boxes
    if config.DETECTION_MIN_CONFIDENCE:
        conf_keep = tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
        keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
                                        tf.expand_dims(conf_keep, 0))
        keep = tf.sparse_tensor_to_dense(keep)[0]

    # Apply per-class NMS
    # 1. Prepare variables
    pre_nms_class_ids = tf.gather(class_ids, keep)
    # pre_nms_scores = tf.gather(class_scores, keep) # [NUM_KEEP]
    pre_nms_scores = tf.gather(class_scores*fb_prob*sem_prob, keep) # [NUM_KEEP]
    pre_nms_rois = tf.gather(refined_rois,   keep)
    unique_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]

    def nms_keep_map(class_id):
        """Apply Non-Maximum Suppression on ROIs of the given class."""
        # Indices of ROIs of the given class
        ixs = tf.where(tf.equal(pre_nms_class_ids, class_id))[:, 0]
        # Apply NMS

        class_keep = tf.py_func(nms_3d, [tf.expand_dims(tf.gather(pre_nms_rois, ixs),0),
            tf.expand_dims(tf.gather(pre_nms_scores, ixs),0),
            -1, config.DETECTION_MAX_INSTANCES,
            config.DETECTION_NMS_THRESHOLD, float('-inf')], tf.int32)
        class_keep = tf.squeeze(class_keep, 0) # [DETECTION_MAX_INSTANCES], -1 padded
        class_keep = tf.gather(class_keep, tf.where(class_keep > -1)[:,0]) # [<=DETECTION_MAX_INSTANCES], no padding

        # Map indices
        class_keep = tf.gather(keep, tf.gather(ixs, class_keep))
        # Pad with -1 so returned tensors have the same shape
        gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0]
        class_keep = tf.pad(class_keep, [(0, gap)],
                            mode='CONSTANT', constant_values=-1)
        # Set shape so map_fn() can infer result shape
        class_keep.set_shape([config.DETECTION_MAX_INSTANCES])
        return class_keep

    # 2. Map over class IDs
    nms_keep = tf.map_fn(nms_keep_map, unique_pre_nms_class_ids,
                         dtype=tf.int64)
    # 3. Merge results into one list, and remove -1 padding
    nms_keep = tf.reshape(nms_keep, [-1])
    nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
    # 4. Compute intersection between keep and nms_keep
    keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
                                    tf.expand_dims(nms_keep, 0))
    keep = tf.sparse_tensor_to_dense(keep)[0]
    # Keep top detections
    roi_count = config.DETECTION_MAX_INSTANCES
    class_scores_keep = tf.gather(class_scores*fb_prob*sem_prob, keep)
    num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
    top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]
    keep = tf.gather(keep, top_ids)

    # Arrange output as [N, (center_x, center_y, center_z, l, w, h, class_id, score)]
    detections = tf.concat([
        tf.gather(refined_rois, keep),
        tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
        tf.gather(class_scores, keep)[..., tf.newaxis]
        ], axis=1)

    # Pad with zeros if detections < DETECTION_MAX_INSTANCES
    gap = config.DETECTION_MAX_INSTANCES - tf.shape(detections)[0]
    detections = tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")
    return detections

def classification_head(pc, pc_fea, num_category, mlp_list, mlp_list2, is_training, bn_decay, scope, bn=True):
    ''' Classification head for both class id prediction and bbox delta regression
    Inputs:
        pc: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
        pc_fea: [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]
        num_category: scalar
    Returns:
        logits: [B, NUM_ROIS, NUM_CATEGORY]
        probs: [B, NUM_ROIS, NUM_CATEGORY]
        bbox_deltas: [B, NUM_ROIS, NUM_CATEGORY, (dz, dy, dx, log(dh), log(dw), log(dl))]
    '''
    with tf.variable_scope(scope) as myscope:
        num_rois = pc.get_shape()[1].value
        grouped_points = tf.concat((pc_fea, pc), -1)
        for i,num_out_channel in enumerate(mlp_list):
            grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1, 1],
                                            padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                            scope='conv_prev_%d'%i, bn_decay=bn_decay)
        new_points = tf.reduce_max(grouped_points, axis=2)
        for i,num_out_channel in enumerate(mlp_list2):
            new_points = tf_util.conv1d(new_points, num_out_channel, 1,
                                        padding='VALID', stride=1, bn=bn, is_training=is_training,
                                        scope='conv_post_%d'%i, bn_decay=bn_decay)
        logits = tf_util.conv1d(new_points, num_category, 1, padding='VALID',
                                stride=1, scope='conv_classify', activation_fn=None)
        probs = tf.nn.softmax(logits, 2)
        bbox_deltas = tf_util.conv1d(new_points, num_category*6, 1, padding='VALID',
                                     stride=1, scope='conv_bbox_regress', activation_fn=None)
        bbox_deltas = tf.reshape(bbox_deltas, [-1, num_rois, num_category, 6])
        return logits, probs, bbox_deltas

def segmentation_head(pc, pc_fea, num_category, mlp_list, mlp_list2, mlp_list3, is_training, bn_decay, scope, bn=True):
    ''' Segmentation head
    Inputs:
        pc: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
        pc_fea: [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]
        num_category: scalar
    Returns:
        masks: [B, NUM_ROIS, NUM_POINT_PER_ROI, NUM_CATEGORY]
    '''
    with tf.variable_scope(scope) as myscope:
        num_rois = pc.get_shape()[1].value
        num_point_per_roi = pc.get_shape()[2].value
        grouped_points = tf.concat((pc_fea, pc), -1)
        for i,num_out_channel in enumerate(mlp_list):
            grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1, 1],
                                            padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                            scope='conv_prev_%d'%i, bn_decay=bn_decay)
        local_feat = grouped_points
        for i,num_out_channel in enumerate(mlp_list2):
            grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1, 1],
                                            padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                            scope='conv_%d'%i, bn_decay=bn_decay)
        global_feat = tf.reduce_max(grouped_points, axis=2, keep_dims=True)
        global_feat_expanded = tf.tile(global_feat, [1, 1, num_point_per_roi, 1])
        new_points = tf.concat((global_feat_expanded, local_feat), -1)
        for i,num_out_channel in enumerate(mlp_list3):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1, 1],
                                            padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                            scope='conv_post_%d'%i, bn_decay=bn_decay)

        masks = tf_util.conv2d(new_points, num_category, [1, 1], padding='VALID',
                               stride=[1,1], scope='conv_seg', activation_fn=None)
        return masks

def dict_stop_gradient(dict_in):
    for key in dict_in:
        dict_in[key] = tf.stop_gradient(dict_in[key])
    return dict_in

def select_segmentation(rpointnet_masks, class_ids):
    ''' Select, for each ROI, the mask channel of its predicted class
    Inputs:
        rpointnet_masks: [B, NUM_ROIS, NUM_POINT_PER_ROI, NUM_CATEGORY]
        class_ids: [B, NUM_ROIS]
    Returns:
        rpointnet_mask_selected: [B, NUM_ROIS, NUM_POINT_PER_ROI]
    '''
    batch_size = rpointnet_masks.get_shape()[0].value
    num_rois = rpointnet_masks.get_shape()[1].value
    num_point_per_roi = rpointnet_masks.get_shape()[2].value
    num_category = rpointnet_masks.get_shape()[3].value

    rpointnet_masks = tf.reshape(rpointnet_masks, [-1, num_point_per_roi, num_category])
    rpointnet_masks = tf.transpose(rpointnet_masks, perm=[0, 2, 1]) # [B*NUM_ROIS, NUM_CATEGORY, NUM_POINT_PER_ROI]
    class_ids = tf.cast(tf.reshape(class_ids, [-1]), tf.int32)
    class_ids_aug = tf.stack([tf.range(batch_size*num_rois, dtype=tf.int32), class_ids], 1)
    rpointnet_mask_selected = tf.gather_nd(rpointnet_masks, class_ids_aug) #[-1, NUM_POINT_PER_ROI]
    rpointnet_mask_selected = tf.reshape(rpointnet_mask_selected, [batch_size, num_rois, num_point_per_roi])

    return rpointnet_mask_selected

def unmold_segmentation(rpointnet_masks, rois, class_ids, pc_coord_cropped, pc):
    ''' Unmold per-ROI segmentation masks back onto the full point cloud
    Inputs:
        rpointnet_masks: [B, NUM_ROIS, NUM_POINT_PER_ROI, NUM_CATEGORY]
        rois: [B, NUM_ROIS, 6]
        class_ids: [B, NUM_ROIS]
        pc_coord_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
        pc: [B, NUM_POINT, 3]
    Returns:
        rpointnet_mask_unmolded: [B, NUM_ROIS, NUM_POINT]
    '''
    batch_size = rpointnet_masks.get_shape()[0].value
    num_rois = rpointnet_masks.get_shape()[1].value
    num_point_per_roi = rpointnet_masks.get_shape()[2].value
    num_category = rpointnet_masks.get_shape()[3].value
    num_point = pc.get_shape()[1].value

    rpointnet_masks = tf.reshape(rpointnet_masks, [-1, num_point_per_roi, num_category])
    rpointnet_masks = tf.transpose(rpointnet_masks, perm=[0, 2, 1]) # [B*NUM_ROIS, NUM_CATEGORY, NUM_POINT_PER_ROI]
    class_ids = tf.cast(tf.reshape(class_ids, [-1]), tf.int32)
    class_ids_aug = tf.stack([tf.range(batch_size*num_rois, dtype=tf.int32), class_ids], 1)
    rpointnet_masks = tf.gather_nd(rpointnet_masks, class_ids_aug) #[-1, NUM_POINT_PER_ROI]

    # [B, NUM_ROIS, NUM_POINT, NUM_POINT_PER_ROI]
    dist = tf.reduce_sum(tf.square(tf.expand_dims(tf.expand_dims(pc, 1),3)-tf.expand_dims(pc_coord_cropped,2)), -1)
    min_idx = tf.argmin(dist, 3, output_type=tf.int32) # [B, NUM_ROIS, NUM_POINT]
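    # Each full-scene point copies the mask value of its nearest cropped point,
    # so the per-ROI mask is transferred back to the original point cloud by
    # nearest-neighbor lookup.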
    min_idx = tf.reshape(min_idx, [-1, num_point]) # [-1, NUM_POINT]
    min_idx_aug = tf.tile(tf.expand_dims(tf.range(batch_size*num_rois, dtype=tf.int32),-1), [1, num_point])
    min_idx_aug = tf.stack([tf.reshape(min_idx_aug, [-1]), tf.reshape(min_idx, [-1])], 1)
    rpointnet_mask_unmolded = tf.reshape(tf.gather_nd(rpointnet_masks, min_idx_aug), [batch_size, num_rois, num_point])

    # Mask out regions outside rois
    pc_aug = tf.expand_dims(pc, 1) # [B, 1, NUM_POINT, 3]
    rois_aug = tf.expand_dims(rois, 2) # [B, NUM_ROIS, 1, 6]
    roi_masks = tf.logical_and(pc_aug>=(rois_aug[:,:,:,:3]-rois_aug[:,:,:,3:]/2),
                               pc_aug<=(rois_aug[:,:,:,:3]+rois_aug[:,:,:,3:]/2))
    roi_masks = tf.logical_and(tf.logical_and(roi_masks[:,:,:,0], roi_masks[:,:,:,1]), roi_masks[:,:,:,2]) # [B, NUM_ROIS, NUM_POINT]
    roi_masks = tf.cast(roi_masks, tf.float32)

    rpointnet_mask_unmolded = tf.multiply(rpointnet_mask_unmolded, roi_masks)
    return rpointnet_mask_unmolded


def rpointnet(pc, color, pc_ins, group_label, group_indicator, seg_label, bbox_ins, config, is_training, mode='training', bn_decay=None):
    ''' R-PointNet: shape proposal generation followed by ROI classification,
        bounding box refinement and mask prediction
    Inputs:
        pc: [B, NUM_POINT, 3]
        color: [B, NUM_POINT, 3]
        pc_ins: [B, NUM_GROUP, NUM_POINT_INS, 3], in world coord sys
        group_label: [B, NUM_POINT]
        group_indicator: [B, NUM_GROUP]
        seg_label: [B, NUM_POINT]
        bbox_ins: [B, NUM_GROUP, 6]
    Returns:
        end_points: dict of intermediate tensors, predictions and targets
            (populated at the end of this function)
    '''
    assert mode in ['training', 'inference']
    if not config.USE_COLOR:
        color = None
    if 'SPN' in config.TRAIN_MODULE and mode=='training':
        end_points = shape_proposal_net(pc, color, pc_ins, group_label, group_indicator, config.NUM_CATEGORY, scope='shape_proposal_net', is_training=is_training, bn_decay=bn_decay, nsmp=config.NUM_SAMPLE, return_fullfea=False, mode=mode)
    else:
        end_points = shape_proposal_net(pc, color, pc_ins, group_label, group_indicator, config.NUM_CATEGORY, scope='shape_proposal_net', is_training=tf.constant(False), bn_decay=None, nsmp=config.NUM_SAMPLE, return_fullfea=True, mode=mode)
        end_points = dict_stop_gradient(end_points)
    if config.SHRINK_BOX:
        end_points['bbox_ins_pred'] = box_shrink(end_points['bbox_ins_pred'], pc)
    group_label_onehot = tf.one_hot(group_label, depth=config.NUM_GROUP, axis=-1) #[B, NUM_POINT, NUM_GROUP]
    seg_label_per_group = tf.multiply(tf.cast(tf.expand_dims(seg_label,-1), tf.float32), group_label_onehot)
    seg_label_per_group = tf.cast(tf.round(tf.divide(tf.reduce_sum(seg_label_per_group, 1),tf.reduce_sum(group_label_onehot, 1)+1e-8)), tf.int32) #[B, NUM_GROUP]
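    # The reduction above yields one semantic label per group: point labels are
    # averaged within each group and rounded (points of a group are assumed to
    # share a single label).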

    if 'RPOINTNET' in config.TRAIN_MODULE or mode=='inference':
        SPN_NMS_MAX_SIZE = config.SPN_NMS_MAX_SIZE_TRAINING if mode == "training"\
            else config.SPN_NMS_MAX_SIZE_INFERENCE
        # 3D non maximum suppression - selected_indices: [B, M], spn_rois: [B, M, 6]
        selected_indices = tf.stop_gradient(tf.py_func(nms_3d, [end_points['bbox_ins_pred'], end_points['fb_prob'][:,:,1], config.SPN_PRE_NMS_LIMIT, SPN_NMS_MAX_SIZE, config.SPN_IOU_THRESHOLD, config.SPN_SCORE_THRESHOLD], tf.int32))
        spn_rois = gather_selection(end_points['bbox_ins_pred'], selected_indices, SPN_NMS_MAX_SIZE)

        if mode=='training':
            # Detection target generation - rois: [B, TRAIN_ROIS_PER_IMAGE, 6], target_class_ids: [B, TRAIN_ROIS_PER_IMAGE]
            # target_bbox: [B, TRAIN_ROIS_PER_IMAGE, 6], target_mask_selection_idx: [B, TRAIN_ROIS_PER_IMAGE, NUM_POINT_INS_MASK]
            # target_mask: [B, TRAIN_ROIS_PER_IMAGE, NUM_POINT_INS_MASK], all zero padded
            names = ["rois", "target_class_ids", "target_bbox", "target_mask_selection_idx", "target_mask"]
            outputs = batch_slice(
                [spn_rois, seg_label_per_group, bbox_ins, group_label_onehot, pc],
                lambda v, w, x, y, z: detection_target_gen(v, w, x, y, z, config),
                config.BATCH_SIZE, names=names)
            rois, target_class_ids, target_bbox, target_mask_selection_idx, target_mask = outputs

            # Points cropping - pc_fea_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]
            # pc_center_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
            # pc_coord_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
            ##### sem fpn fea
            sem_fea_full_l1 = tf_util.conv1d(end_points['sem_fea_full_l1'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn1', bn_decay=bn_decay)
            sem_fea_full_l2 = tf_util.conv1d(end_points['sem_fea_full_l2'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn2', bn_decay=bn_decay)
            sem_fea_full_l3 = tf_util.conv1d(end_points['sem_fea_full_l3'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn3', bn_decay=bn_decay)
            sem_fea_full_l4 = tf_util.conv1d(end_points['sem_fea_full_l4'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn4', bn_decay=bn_decay)
            pc_fea_cropped, pc_center_cropped, pc_coord_cropped, _ = points_cropping(pc, tf.concat((end_points['entity_fea'], sem_fea_full_l1, sem_fea_full_l2, sem_fea_full_l3, sem_fea_full_l4), -1), 
                end_points['center_pos'], rois, target_mask_selection_idx, config.TRAIN_ROIS_PER_IMAGE, config.NUM_POINT_INS_MASK, config.NORMALIZE_CROP_REGION)

            # Classification and bbox refinement head
            rpointnet_class_logits, rpointnet_class, rpointnet_bbox = classification_head(pc_coord_cropped, 
                tf.concat((pc_fea_cropped, pc_center_cropped), -1), config.NUM_CATEGORY, 
                [128, 256, 512], [256, 256], is_training, bn_decay, 'classification_head')

            # Mask prediction head
            rpointnet_mask = segmentation_head(pc_coord_cropped,
                tf.concat((pc_fea_cropped, pc_center_cropped), -1), config.NUM_CATEGORY,
                [64, 64], [64, 128, 512], [256, 256], is_training, bn_decay, 'segmentation_head')
        elif mode=='inference':
            # rois: [B, NUM_ROIS, 6]
            names = ["rois", "mask_selection_idx"]
            outputs = batch_slice(
                [spn_rois, pc],
                lambda x, y: mask_selection_gen(x, y, SPN_NMS_MAX_SIZE, config, empty_removal=True),
                config.BATCH_SIZE, names=names)
            rois, mask_selection_idx = outputs

            # Points cropping - pc_fea_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, NFEA]
            # pc_center_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
            # pc_coord_cropped: [B, NUM_ROIS, NUM_POINT_PER_ROI, 3]
            ##### sem fpn fea
            sem_fea_full_l1 = tf_util.conv1d(end_points['sem_fea_full_l1'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn1', bn_decay=bn_decay)
            sem_fea_full_l2 = tf_util.conv1d(end_points['sem_fea_full_l2'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn2', bn_decay=bn_decay)
            sem_fea_full_l3 = tf_util.conv1d(end_points['sem_fea_full_l3'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn3', bn_decay=bn_decay)
            sem_fea_full_l4 = tf_util.conv1d(end_points['sem_fea_full_l4'], 64, 1, padding='VALID', bn=True, is_training=is_training, scope='fpn4', bn_decay=bn_decay)

            #### generate fb_conf and sem_conf
            # pc: [B, N, 3], fb_prob: [B, nsmp, 2], pc_seed: [B, nsmp, 3] -> fb_prob: [B, N]
            midx = tf.argmin(tf.reduce_sum(tf.square(tf.expand_dims(pc, 2)-tf.expand_dims(end_points['pc_seed'],1)),-1), 2)
            midx_aug = tf.tile(tf.reshape(tf.range(config.BATCH_SIZE, dtype=tf.int64),[-1,1]), [1,config.NUM_POINT])
            midx_aug = tf.stack((tf.reshape(midx_aug, [-1]), tf.reshape(midx, [-1])), 1)
            fb_prob = tf.reshape(tf.gather_nd(end_points['fb_prob'], midx_aug), [config.BATCH_SIZE, config.NUM_POINT, 2])
            fb_prob = fb_prob[:,:,1]
            sem_prob = tf.nn.softmax(end_points['sem_class_logits'], -1) #[B, NUM_POINT, NUM_CATEGORY]

            ##### points cropping with backbone + confidence features
            pc_fea_cropped, pc_center_cropped, pc_coord_cropped, pc_coord_cropped_unnormalized = points_cropping(pc, tf.concat((end_points['entity_fea'], sem_fea_full_l1, sem_fea_full_l2, sem_fea_full_l3, sem_fea_full_l4, tf.expand_dims(fb_prob, -1), sem_prob), -1), 
                end_points['center_pos'], rois, mask_selection_idx, SPN_NMS_MAX_SIZE, config.NUM_POINT_INS_MASK, config.NORMALIZE_CROP_REGION)
            pc_fea_cropped, fb_prob_cropped, sem_prob_cropped = tf.split(pc_fea_cropped, [1024, 1, config.NUM_CATEGORY], -1)
            
            fb_prob_cropped = tf.reduce_mean(tf.squeeze(fb_prob_cropped, -1),-1) # [B, NUM_ROIS]
            end_points['fb_prob_cropped'] = fb_prob_cropped
            sem_prob_cropped = tf.reduce_mean(sem_prob_cropped, 2) # [B, NUM_ROIS, NUM_CATEGORY]

            # Classification and bbox refinement head - rpointnet_class_logits: [B, NUM_ROIS, NUM_CATEGORY]
            # rpointnet_class: [B, NUM_ROIS, NUM_CATEGORY],
            # rpointnet_bbox: [B, NUM_ROIS, NUM_CATEGORY, 6]
            rpointnet_class_logits, rpointnet_class, rpointnet_bbox = classification_head(pc_coord_cropped, 
                tf.concat((pc_fea_cropped, pc_center_cropped), -1), config.NUM_CATEGORY, 
                [128, 256, 512], [256, 256], is_training, bn_decay, 'classification_head')

            midx = tf.argmax(rpointnet_class_logits, -1) # [B, NUM_ROIS]
            midx = tf.stack((tf.range(config.BATCH_SIZE*SPN_NMS_MAX_SIZE, dtype=tf.int64), tf.reshape(midx,[-1])), 1)
            sem_prob_cropped = tf.gather_nd(tf.reshape(sem_prob_cropped, [-1, config.NUM_CATEGORY]), midx)
            sem_prob_cropped = tf.reshape(sem_prob_cropped, [config.BATCH_SIZE, SPN_NMS_MAX_SIZE])
            end_points['sem_prob_cropped'] = sem_prob_cropped

            # Generate detections: [B, DETECTION_MAX_INSTANCES, (center_x, center_y, center_z, l, w, h, class_id, score)]
            detections = batch_slice(
                [rois, rpointnet_class, rpointnet_bbox, pc, fb_prob_cropped, sem_prob_cropped],
                lambda u, v, x, y, w, z: refine_detections(u, v, x, y, w, z, config),
                config.BATCH_SIZE)

            # Re-crop point cloud for mask prediction
            names = ["rois_final", "mask_selection_idx_final"]
            outputs = batch_slice(
                [detections[:,:,:6], pc],
                lambda x, y: mask_selection_gen(x, y, config.DETECTION_MAX_INSTANCES, config, empty_removal=False),
                config.BATCH_SIZE, names=names)
            rois_final, mask_selection_idx_final = outputs
            # Points cropping - pc_fea_cropped_final: [B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI, NFEA]
            # pc_center_cropped_final: [B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI, 3]
            # pc_coord_cropped_final: [B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI, 3]
            ##### points cropping for final detections
            pc_fea_cropped_final, pc_center_cropped_final, pc_coord_cropped_final, pc_coord_cropped_final_unnormalized = points_cropping(pc, tf.concat((end_points['entity_fea'], sem_fea_full_l1, sem_fea_full_l2, sem_fea_full_l3, sem_fea_full_l4), -1), 
                end_points['center_pos'], rois_final, mask_selection_idx_final, config.DETECTION_MAX_INSTANCES, config.NUM_POINT_INS_MASK, config.NORMALIZE_CROP_REGION)

            # Mask prediction head
            rpointnet_mask = segmentation_head(pc_coord_cropped_final,
                tf.concat((pc_fea_cropped_final, pc_center_cropped_final), -1), config.NUM_CATEGORY,
                [64, 64], [64, 128, 512], [256, 256], is_training, bn_decay, 'segmentation_head')

            # Unmold segmentation
            rpointnet_mask_selected = select_segmentation(tf.nn.sigmoid(rpointnet_mask), detections[:,:,6])

    # Update end_points
    end_points['group_label'] = group_label
    end_points['seg_label'] = seg_label
    end_points['seg_label_per_group'] = seg_label_per_group #[B, NUM_GROUP]
    end_points['bbox_ins'] = bbox_ins #[B, NUM_GROUP, 6]
    if 'RPOINTNET' in config.TRAIN_MODULE and mode=='training':
        end_points['selected_indices'] = selected_indices #[B, SPN_NMS_MAX_SIZE]
        end_points['spn_rois'] = spn_rois #[B, SPN_NMS_MAX_SIZE, 6]
        end_points['rois'] = rois #[B, NUM_ROIS, 6]
        end_points['target_class_ids'] = target_class_ids #[B, NUM_ROIS]
        end_points['target_bbox'] = target_bbox #[B, NUM_ROIS, 6]
        end_points['target_mask_selection_idx'] = target_mask_selection_idx #[B, NUM_ROIS, NUM_POINT_PER_ROI]
        end_points['target_mask'] = target_mask #[B, NUM_ROIS, NUM_POINT_PER_ROI]
        end_points['rpointnet_class_logits'] = rpointnet_class_logits #[B, NUM_ROIS, NUM_CATEGORY]
        end_points['rpointnet_class'] = rpointnet_class #[B, NUM_ROIS, NUM_CATEGORY]
        end_points['rpointnet_bbox'] = rpointnet_bbox #[B, NUM_ROIS, NUM_CATEGORY, 6]
        end_points['rpointnet_mask'] = rpointnet_mask #[B, NUM_ROIS, NUM_POINT_PER_ROI, NUM_CATEGORY]
    elif mode=='inference':
        end_points['selected_indices'] = selected_indices #[B, SPN_NMS_MAX_SIZE]
        end_points['spn_rois'] = spn_rois #[B, SPN_NMS_MAX_SIZE, 6]
        end_points['rois'] = rois #[B, NUM_ROIS, 6]
        end_points['rpointnet_class_logits'] = rpointnet_class_logits #[B, NUM_ROIS, NUM_CATEGORY]
        end_points['rpointnet_class'] = rpointnet_class #[B, NUM_ROIS, NUM_CATEGORY]
        end_points['rpointnet_bbox'] = rpointnet_bbox #[B, NUM_ROIS, NUM_CATEGORY, 6]
        end_points['detections'] = detections #[B, DETECTION_MAX_INSTANCES, 6+2]
        end_points['rpointnet_mask'] = rpointnet_mask #[B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI, NUM_CATEGORY]
        end_points['rpointnet_mask_selected'] = rpointnet_mask_selected #[B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI]
        end_points['pc_coord_cropped_final_unnormalized'] = pc_coord_cropped_final_unnormalized #[B, DETECTION_MAX_INSTANCES, NUM_POINT_PER_ROI, 3]
    return end_points
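
def build_inference_graph(myconfig):
    ''' Hedged usage sketch, not wired into training: builds the inference-mode
        graph and returns the final detections together with their selected
        per-point masks. Assumes a config.Config instance with the fields used
        above. '''
    pc_pl, color_pl, pc_ins_pl, group_label_pl, group_indicator_pl, \
        seg_label_pl, bbox_ins_pl = placeholder_inputs(myconfig)
    end_points = rpointnet(pc_pl, color_pl, pc_ins_pl, group_label_pl,
                           group_indicator_pl, seg_label_pl, bbox_ins_pl,
                           myconfig, tf.constant(False), mode='inference',
                           bn_decay=None)
    # detections: [B, DETECTION_MAX_INSTANCES, 6+2], last two entries are
    # class id and score
    return end_points['detections'], end_points['rpointnet_mask_selected']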

def smooth_l1_loss(y_true, y_pred):
    """Implements Smooth-L1 loss.
    y_true and y_pred are typically [N, 6], but any shape works since the loss
    is computed elementwise.
    """
    diff = tf.abs(y_true - y_pred)
    less_than_one = tf.cast(tf.less(diff, 1.0), tf.float32)
    loss = (less_than_one * 0.5 * diff**2) + (1 - less_than_one) * (diff - 0.5)
    return loss
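
def _smooth_l1_example():
    ''' Hedged sanity sketch showing the two smooth-L1 regimes: for |diff| = 0.5
        the loss is quadratic, 0.5 * 0.5**2 = 0.125; for |diff| = 2.0 it is
        linear, 2.0 - 0.5 = 1.5. '''
    with tf.Session() as sess:
        loss = smooth_l1_loss(tf.constant([0.0, 0.0]), tf.constant([0.5, 2.0]))
        print(sess.run(loss))  # [0.125 1.5]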

def get_spn_class_loss(fb_logits, spn_match):
    '''
    Inputs:
       fb_logits: [B, nsmp, 2]
       spn_match: [B, nsmp]
    '''
    # Only positive and negative samples contribute to the loss; neutral ones are ignored
    fb_logits = tf.reshape(fb_logits, [-1, 2])
    spn_match = tf.reshape(spn_match, [-1])
    valid_mask = tf.not_equal(spn_match, 0)
    fb_logits = tf.boolean_mask(fb_logits, valid_mask)
    spn_match = tf.cast(tf.equal(spn_match, 1), tf.int32)
    spn_match = tf.boolean_mask(spn_match, valid_mask)
    loss = tf.cond(tf.size(spn_match)>0,
        lambda: tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=spn_match, logits=fb_logits)),
        lambda: tf.constant(0.0))

    return loss
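
def _spn_class_loss_example():
    ''' Hedged sanity sketch: only seeds matched as positive (+1) or negative
        (-1) contribute; the neutral seed (0) is masked out before the softmax
        cross entropy is averaged. '''
    fb_logits = tf.constant([[[2.0, -2.0], [0.0, 0.0], [-2.0, 2.0]]])  # [1, 3, 2]
    spn_match = tf.constant([[-1, 0, 1]])  # bg, neutral, fg
    with tf.Session() as sess:
        # both kept seeds are classified correctly, so the loss is small (~0.018)
        print(sess.run(get_spn_class_loss(fb_logits, spn_match)))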

def get_rpointnet_class_loss(rpointnet_class_logits, gt_class_ids, roi_valid_mask):
    '''
    Inputs:
       rpointnet_class_logits: [B, NUM_ROIS, NUM_CATEGORY]
       gt_class_ids: [B, NUM_ROIS], zero padded
       roi_valid_mask: [B, NUM_ROIS]
    '''
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=gt_class_ids, logits=rpointnet_class_logits)
    loss = tf.multiply(loss, roi_valid_mask)
    loss = tf.divide(tf.reduce_sum(loss), tf.reduce_sum(roi_valid_mask)+1e-8)

    return loss

def get_rpointnet_bbox_loss(gt_bbox, gt_class_ids, pred_bbox, roi_valid_mask, num_category):
    '''
    Inputs:
       gt_bbox: [B, NUM_ROIS, 6]
       gt_class_ids: [B, NUM_ROIS], zero padded
       pred_bbox: [B, NUM_ROIS, NUM_CATEGORY, 6]
       roi_valid_mask: [B, NUM_ROIS]
       num_category: scalar
    '''
    # Only foreground boxes contribute to the loss
    gt_bbox = tf.reshape(gt_bbox, [-1, 6])
    gt_class_ids = tf.reshape(gt_class_ids, [-1])
    pred_bbox = tf.reshape(pred_bbox, [-1, num_category, 6])
    roi_valid_mask = tf.reshape(roi_valid_mask, [-1])

    gt_selected_indices = tf.where(tf.logical_and(tf.greater(roi_valid_mask, 0), tf.greater(gt_class_ids, 0)))[:,0]
    gt_selected_indices = tf.cast(gt_selected_indices, tf.int32)
    pred_selected_indices = tf.concat((tf.reshape(gt_selected_indices, [-1,1]),
        tf.reshape(tf.gather(gt_class_ids, gt_selected_indices), [-1,1])), axis=1)

    gt_bbox = tf.gather(gt_bbox, gt_selected_indices, axis=0)
    pred_bbox = tf.gather_nd(pred_bbox, pred_selected_indices)

    loss = tf.cond(tf.size(gt_bbox)>0,
        lambda: tf.reduce_mean(tf.reduce_sum(smooth_l1_loss(y_true=gt_bbox, y_pred=pred_bbox),1),0),
        lambda: tf.constant(0.0))

    return loss
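
def _foreground_selection_example():
    ''' Hedged sketch of the index construction shared by the bbox and mask
        losses: only ROIs that are valid AND have a positive class id are kept,
        and gather_nd then picks the per-class prediction for each of them. '''
    gt_class_ids = tf.constant([0, 3, 2, 0])            # ROIs 1 and 2 are foreground
    roi_valid_mask = tf.constant([1.0, 1.0, 1.0, 0.0])  # ROI 3 is padding
    idx = tf.where(tf.logical_and(tf.greater(roi_valid_mask, 0),
                                  tf.greater(gt_class_ids, 0)))[:, 0]
    with tf.Session() as sess:
        print(sess.run(idx))  # [1 2]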

def get_rpointnet_mask_loss(gt_masks, gt_class_ids, pred_masks, roi_valid_mask, num_category, num_point_per_roi):
    '''
    Inputs:
       gt_masks: [B, NUM_ROIS, NUM_POINT_PER_ROI]
       gt_class_ids: [B, NUM_ROIS], zero padded
       pred_masks: [B, NUM_ROIS, NUM_POINT_PER_ROI, NUM_CATEGORY]
       roi_valid_mask: [B, NUM_ROIS]
       num_category: scalar
       num_point_per_roi: scalar
    '''
    # Only foreground boxes contribute to the loss
    gt_masks = tf.reshape(gt_masks, [-1, num_point_per_roi])
    gt_class_ids = tf.reshape(gt_class_ids, [-1])
    pred_masks = tf.reshape(pred_masks, [-1, num_point_per_roi, num_category])
    pred_masks = tf.transpose(pred_masks, perm=[0,2,1])
    roi_valid_mask = tf.reshape(roi_valid_mask, [-1])

    gt_selected_indices = tf.where(tf.logical_and(tf.greater(roi_valid_mask, 0), tf.greater(gt_class_ids, 0)))[:,0]
    gt_selected_indices = tf.cast(gt_selected_indices, tf.int32)
    pred_selected_indices = tf.concat((tf.reshape(gt_selected_indices, [-1,1]),
        tf.reshape(tf.gather(gt_class_ids, gt_selected_indices), [-1,1])), axis=1)

    gt_masks = tf.gather(gt_masks, gt_selected_indices, axis=0) # [N, NUM_POINT_PER_ROI]
    gt_masks = tf.cast(gt_masks, tf.float32)
    pred_masks = tf.gather_nd(pred_masks, pred_selected_indices) # [N, NUM_POINT_PER_ROI]

    loss = tf.cond(tf.size(gt_masks)>0,
        lambda: tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=gt_masks, logits=pred_masks)),
        lambda: tf.constant(0.0))

    return loss

def get_loss(end_points, config, alpha, smpw, mode='training'):
    batch_size = config.BATCH_SIZE
    nsmp_ins = config.NUM_POINT_INS
    bbox_size = tf.reduce_max(end_points['pc_ins_centered_seed'], axis=2, keep_dims=True)-tf.reduce_min(end_points['pc_ins_centered_seed'], axis=2, keep_dims=True)
    radius = 1e-8 + tf.sqrt(tf.reduce_sum(tf.square(bbox_size/2), axis=-1, keep_dims=True)) # [B, nsmp, 1, 1]
    shift_gt_seed = end_points['pc_ins_center_seed']-tf.expand_dims(end_points['pc_seed'],2) # [B, nsmp, 1, 3]
    shift_dist = tf.sqrt(tf.reduce_sum(tf.square(shift_gt_seed), 3, keep_dims=True)+1e-8)
    shift_gt_seed_normalized_4d = tf.concat((tf.divide(shift_gt_seed, shift_dist), tf.divide(shift_dist, radius)), -1)
    shift_pred_seed_normalized_4d = tf.concat((tf.expand_dims(end_points['shift_pred_seed_4d'],2)[:,:,:,:3], tf.divide(tf.expand_dims(end_points['shift_pred_seed_4d'],2)[:,:,:,3:], radius)), -1)

    # Fg/Bg loss, spn_match: [B, nsmp]
    fb_score_gt = tf.squeeze(gather_selection(tf.expand_dims(end_points['seg_label'],-1), end_points['ind_seed'], end_points['ind_seed'].get_shape()[1].value),-1)
    fb_score_gt = tf.cast(tf.greater(fb_score_gt,0), tf.float32) # [B, nsmp]
    spn_match = batch_slice(
        [end_points['bbox_ins_pred'], fb_score_gt, tf.cast(tf.greater(end_points['seg_label_per_group'],0),tf.float32), end_points['bbox_ins']],
        lambda w, x, y, z: spn_target_gen(w, x, y, z),
        batch_size, names=["spn_match"])
    spn_match = tf.stop_gradient(tf.cast(spn_match, tf.int32))
    end_points['spn_match'] = spn_match
    spn_class_loss = get_spn_class_loss(end_points['fb_logits'], spn_match)

    # Reconstruction loss
    pc_ins_pred = end_points['pc_ins_pred']
    pc_ins_pred_normalized = tf.reshape(tf.divide(pc_ins_pred, radius), [-1, nsmp_ins, 3]) # [B*nsmp, nsmp_ins, 3]
    pc_ins_gt_normalized = tf.reshape(tf.divide(end_points['pc_ins_centered_seed']+shift_gt_seed, radius), [-1, nsmp_ins, 3]) # [B*nsmp, nsmp_ins, 3]
    recon_valid_mask = tf.reshape(fb_score_gt, [-1])
    recon_valid_mask = tf.stop_gradient(recon_valid_mask)
    dists_forward,_,dists_backward,_ = tf_nndistance.nn_distance(pc_ins_pred_normalized, pc_ins_gt_normalized)
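    # dists_forward / dists_backward are squared nearest-neighbor distances, so
    # their sum averaged over points is the (squared) Chamfer distance between
    # the predicted and ground-truth instance point clouds.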
    recons_loss = tf.reduce_mean(dists_forward+dists_backward, axis=-1) # B*nsmp
    recons_loss = tf.divide(tf.reduce_sum(tf.multiply(recons_loss, recon_valid_mask)),
        tf.reduce_sum(recon_valid_mask)+1e-8)

    # Shift loss
    shift_loss = tf.reduce_sum(smooth_l1_loss(shift_gt_seed_normalized_4d, shift_pred_seed_normalized_4d), axis=-1)
    shift_loss = tf.reduce_sum(tf.multiply(tf.reshape(shift_loss, [-1]), recon_valid_mask))
    shift_loss = tf.divide(shift_loss, tf.reduce_sum(recon_valid_mask)+1e-8)

    # Sem loss
    ind_sem = end_points['ind_sem']
    nsmp_sem = ind_sem.get_shape()[1].value
    sem_labels = end_points['seg_label']
    ind_sem_aug = tf.tile(tf.reshape(tf.range(batch_size),[-1, 1]), [1, nsmp_sem])
    ind_sem_aug = tf.concat( (tf.reshape(ind_sem_aug, [-1, 1]), tf.reshape(ind_sem, [-1, 1])), 1 )
    sem_labels = tf.reshape(tf.gather_nd(sem_labels, ind_sem_aug), [batch_size, nsmp_sem])
    smpw = tf.reshape(tf.gather_nd(smpw, ind_sem_aug), [batch_size, nsmp_sem])
    sem_labels = tf.cast(sem_labels, tf.int32)
    sem_loss = tf.losses.sparse_softmax_cross_entropy(labels=sem_labels, logits=end_points['sem_class_logits'], weights=smpw)
    end_points['sem_labels'] = sem_labels

    # KL loss
    mean = end_points['mean'] # [B, nsmp, 256]
    log_var = end_points['log_var']
    cmean = end_points['cmean']
    clog_var = end_points['clog_var']
    kl_loss = 0.5 * tf.reduce_mean( log_var - clog_var + (tf.exp(clog_var) + (mean-cmean)**2)/tf.exp(log_var) - 1.0, 2) 
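    # The term above is the closed-form KL divergence between diagonal Gaussians,
    # KL( N(cmean, exp(clog_var)) || N(mean, exp(log_var)) )
    #   = 0.5 * ( log_var - clog_var + (exp(clog_var) + (mean - cmean)^2) / exp(log_var) - 1 )
    # per dimension, here averaged (rather than summed) over the latent dimension.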
    kl_loss = tf.divide(tf.reduce_sum(tf.multiply(tf.reshape(kl_loss, [-1]), recon_valid_mask)),
        tf.reduce_sum(recon_valid_mask)+1e-8)

    if 'RPOINTNET' in config.TRAIN_MODULE and mode=='training':
        # rpointnet classification loss
        roi_valid_mask = tf.cast(tf.not_equal(tf.reduce_sum(tf.abs(end_points['rois']), axis=-1),0), tf.float32)
        rpointnet_class_loss = get_rpointnet_class_loss(end_points['rpointnet_class_logits'], end_points['target_class_ids'], roi_valid_mask)

        # rpointnet bbox loss
        rpointnet_bbox_loss = get_rpointnet_bbox_loss(end_points['target_bbox'], end_points['target_class_ids'], end_points['rpointnet_bbox'], roi_valid_mask, config.NUM_CATEGORY)

        # rpointnet mask loss
        rpointnet_mask_loss = get_rpointnet_mask_loss(end_points['target_mask'], end_points['target_class_ids'], end_points['rpointnet_mask'], roi_valid_mask, config.NUM_CATEGORY, config.NUM_POINT_INS_MASK)
    
    if 'SPN' in config.TRAIN_MODULE:
        loss = kl_loss * alpha + recons_loss + shift_loss + spn_class_loss + sem_loss
        if 'RPOINTNET' in config.TRAIN_MODULE and mode=='training':
            loss += rpointnet_class_loss + rpointnet_bbox_loss + rpointnet_mask_loss
    elif 'RPOINTNET' in config.TRAIN_MODULE and mode=='training':
        loss = rpointnet_class_loss + rpointnet_bbox_loss + rpointnet_mask_loss
    else:
        loss = tf.constant(0.0)


    # Store end_points
    end_points['spn_class_loss'] = spn_class_loss
    end_points['recons_loss'] = recons_loss
    end_points['shift_loss'] = shift_loss
    end_points['sem_loss'] = sem_loss
    end_points['kl_loss'] = kl_loss
    if 'RPOINTNET' in config.TRAIN_MODULE and mode=='training':
        end_points['rpointnet_class_loss'] = rpointnet_class_loss
        end_points['rpointnet_bbox_loss'] = rpointnet_bbox_loss
        end_points['rpointnet_mask_loss'] = rpointnet_mask_loss
    end_points['loss'] = loss
    
    return loss, end_points


if __name__=='__main__':
    myconfig = config.Config()
    with tf.Graph().as_default():
        pc_pl, color_pl, pc_ins_pl, group_label_pl, group_indicator_pl, seg_label_pl, bbox_ins_pl = placeholder_inputs(myconfig)
        end_points = rpointnet(pc_pl, color_pl, pc_ins_pl, group_label_pl, group_indicator_pl, seg_label_pl, bbox_ins_pl, myconfig, tf.constant(True), bn_decay=None)
        # get_loss also expects per-point semantic weights; uniform weights
        # serve as a smoke-test stand-in here.
        smpw = tf.ones((myconfig.BATCH_SIZE, myconfig.NUM_POINT), dtype=tf.float32)
        loss, end_points = get_loss(end_points, myconfig, 1.0, smpw)
        print(end_points)