package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelData;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelData.ClusterSummary;
import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelDataConverter;
import com.alibaba.alink.operator.common.clustering.DistanceType;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.statistics.StatisticsHelper;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.params.clustering.BisectingKMeansTrainParams;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
 * Bisecting k-means is a hierarchical clustering algorithm.
 * <p>
 * The algorithm starts from a single cluster that contains all points. It iteratively finds
 * divisible clusters on the bottom level and bisects each of them using k-means, until there are
 * `k` leaf clusters in total or no leaf clusters are divisible.
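 * <p>
 * A minimal usage sketch (the source operator {@code source} and the vector column name
 * "vec" are illustrative assumptions):
 * <pre>{@code
 * BatchOperator<?> model = new BisectingKMeansTrainBatchOp()
 *     .setVectorCol("vec")
 *     .setK(4)
 *     .setMaxIter(10)
 *     .linkFrom(source);
 * }</pre>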
 *
 * @see <a href="http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf">
 * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, KDD Workshop on Text Mining,
 * 2000.</a>
 */
public final class BisectingKMeansTrainBatchOp extends BatchOperator<BisectingKMeansTrainBatchOp>
    implements BisectingKMeansTrainParams<BisectingKMeansTrainBatchOp> {

    public static final long ROOT_INDEX = 1;
    private static final Logger LOG = LoggerFactory.getLogger(BisectingKMeansTrainBatchOp.class);
    private static final String VECTOR_SIZE = "vectorSize";
    private static final String DIVISIBLE_INDICES = "divisibleIndices";
    private static final String ITER_INFO = "iterInfo";
    private static final String NEW_CLUSTER_CENTERS = "newClusterCenters";

    public BisectingKMeansTrainBatchOp() {
        this(new Params());
    }

    public BisectingKMeansTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Returns the left child index of the given node index.
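     * Cluster ids form an implicit binary tree: the root is {@link #ROOT_INDEX} (1), and node
     * {@code i} has children {@code 2 * i} and {@code 2 * i + 1}.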
     */
    public static long leftChildIndex(long index) {
        return 2 * index;
    }

    /**
     * Returns the right child index of the given node index.
     */
    public static long rightChildIndex(long index) {
        return 2 * index + 1;
    }

    private static DataSet<Tuple2<Long, DenseVector>>
    getNewClusterCenters(DataSet<Tuple3<Long, ClusterSummary, IterInfo>> allClusterSummaries) {
        return allClusterSummaries
            .flatMap(new FlatMapFunction<Tuple3<Long, ClusterSummary, IterInfo>, Tuple2<Long, DenseVector>>() {
                @Override
                public void flatMap(Tuple3<Long, ClusterSummary, IterInfo> value,
                                    Collector<Tuple2<Long, DenseVector>> out) {
                    if (value.f2.isNew) {
                        out.collect(Tuple2.of(value.f0, value.f1.center));
                    }
                }
            })
            .name("getNewClusterCenters");
    }

    private static DataSet<Long>
    getDivisibleClusterIndices(
        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> allClusterSummaries) {
        return allClusterSummaries
            .flatMap(new FlatMapFunction<Tuple3<Long, ClusterSummary, IterInfo>, Long>() {
                @Override
                public void flatMap(Tuple3<Long, ClusterSummary, IterInfo> value,
                                    Collector<Long> out) {
                    LOG.info("getDivisibleS {}", value);
                    if (value.f2.isDividing) {
                        out.collect(value.f0);
                    }
                }
            })
            .name("getDivisibleClusterIndices");
    }

    /**
     * At the cluster-dividing step, divides the currently divisible clusters; otherwise, just
     * copies the existing clusters.
     *
     * @param clustersSummariesAndIterInfo clusterId, clusterSummary, IterInfo
     * @param k                            cluster number
     * @return original clusters and new clusters.
     */
    private static DataSet<Tuple3<Long, ClusterSummary, IterInfo>> getOrSplitClusters(
        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> clustersSummariesAndIterInfo,
        final int k,
        final int minDivisibleClusterSize) {
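        // Route every cluster summary to partition 0 so that a single subtask observes the
        // whole cluster tree and can decide consistently which leaves to split.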
        return clustersSummariesAndIterInfo
            .partitionCustom(
                new Partitioner<Integer>() {
                    @Override
                    public int partition(Integer key, int numPartitions) {
                        return 0;
                    }
                }, new KeySelector<Tuple3<Long, ClusterSummary, IterInfo>, Integer>() {
                    @Override
                    public Integer getKey(Tuple3<Long, ClusterSummary, IterInfo> value) {
                        return 0;
                    }
                })
            .mapPartition(
                new RichMapPartitionFunction<Tuple3<Long, ClusterSummary, IterInfo>, Tuple3
                    <Long, ClusterSummary, IterInfo>>() {
                    private transient Random random;

                    @Override
                    public void open(Configuration parameters) {
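                        // open() is re-invoked at every superstep of the enclosing iteration;
                        // the null check keeps a single Random instance (and its sequence)
                        // across supersteps on subtask 0.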
                        if (random == null && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                            random = new Random(getRuntimeContext().getIndexOfThisSubtask());
                        }
                    }

                    @Override
                    public void mapPartition(
                        Iterable<Tuple3<Long, ClusterSummary, IterInfo>> summaries,
                        Collector<Tuple3<Long, ClusterSummary, IterInfo>> out) {
                        if (getRuntimeContext().getIndexOfThisSubtask() > 0) {
                            return;
                        }

                        List<Tuple3<Long, ClusterSummary, IterInfo>> clustersAndIterInfo = new ArrayList<>();
                        summaries.forEach(clustersAndIterInfo::add);

                        // At the first step of a bisecting round, split the divisible leaf clusters
                        if (clustersAndIterInfo.get(0).f2.doBisectionInStep()) {
                            // find all splitable clusters
                            Set<Long> splitableClusters = findSplitableClusters(clustersAndIterInfo, k,
                                minDivisibleClusterSize);
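                            // Splitting turns each splittable leaf into two leaves, so after this
                            // round the tree will have (leaves + splittable) leaves; once that
                            // reaches k, this bisecting step is the last one.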
                            boolean shouldStopSplit = (splitableClusters.size() + getNumLeaf(clustersAndIterInfo)) >= k;

                            // split clusters
                            clustersAndIterInfo.forEach(t -> {
                                assert (!t.f2.isDividing);
                                assert (!t.f2.isNew);
                                t.f2.shouldStopSplit = shouldStopSplit;
                                if (splitableClusters.contains(t.f0)) {
                                    ClusterSummary summary = t.f1;
                                    IterInfo newCenterIterInfo = new IterInfo(t.f2.maxIter, t.f2.bisectingStepNo,
                                        t.f2.innerIterStepNo, false, true, shouldStopSplit);
                                    Tuple2<DenseVector, DenseVector> newCenters = initialSplitCenter(summary.center, random);
                                    ClusterSummary leftChildSummary = new ClusterSummary();
                                    leftChildSummary.center = newCenters.f0;
                                    ClusterSummary rightChildSummary = new ClusterSummary();
                                    rightChildSummary.center = newCenters.f1;
                                    t.f2.isDividing = true;
                                    out.collect(t);
                                    out.collect(Tuple3.of(leftChildIndex(t.f0), leftChildSummary, newCenterIterInfo));
                                    out.collect(Tuple3.of(rightChildIndex(t.f0), rightChildSummary, newCenterIterInfo));
                                } else {
                                    out.collect(t);
                                }
                            });
                        } else {
                            // copy existing clusters
                            clustersAndIterInfo.forEach(out::collect);
                        }
                    }
                })
            .name("get_or_split_clusters");
    }

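    /**
     * Seeds the two child centers by perturbing the parent center with a random noise vector
     * whose entries are at most 1e-4 of the parent center's L2 norm: the left child starts at
     * {@code center - noise} and the right child at {@code center + noise}.
     */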
    private static Tuple2<DenseVector, DenseVector> initialSplitCenter(DenseVector center, Random random) {
        int dim = center.size();
        double norm = Math.sqrt(BLAS.dot(center, center));
        double level = 1.0e-4 * norm;
        DenseVector noise = new DenseVector(dim);
        for (int i = 0; i < dim; i++) {
            noise.set(i, level * random.nextDouble());
        }
        return Tuple2.of(center.minus(noise), center.plus(noise));
    }

    private static Set<Long> findSplitableClusters(List<Tuple3<Long, ClusterSummary, IterInfo>> allClusterSummaries,
                                                   final int k,
                                                   final int minDivisibleClusterSize) {
        Set<Long> clusterIds = new HashSet<>();
        List<Long> leafs = new ArrayList<>();
        List<Tuple3<Long, ClusterSummary, IterInfo>> splitableClusters = new ArrayList<>();
        allClusterSummaries.forEach(t -> clusterIds.add(t.f0));
        LOG.info("existingClusterIds {}", JsonConverter.toJson(clusterIds));
        allClusterSummaries.forEach(t -> {
            boolean isLeaf = isLeaf(clusterIds, t.f0);
            if (isLeaf) {
                leafs.add(t.f0);
            }
            if (isLeaf && t.f1.size > 1 && t.f1.size > minDivisibleClusterSize) {
                splitableClusters.add(t);
            }
        });
        int numClusterToSplit = k - leafs.size();
        List<Long> splitableClusterIds = new ArrayList<>();

        splitableClusters.sort(new Comparator<Tuple3<Long, ClusterSummary, IterInfo>>() {
            @Override
            public int compare(Tuple3<Long, ClusterSummary, IterInfo> o1, Tuple3<Long, ClusterSummary, IterInfo> o2) {
                return -Double.compare(o1.f1.cost, o2.f1.cost);
            }
        });
        for (int i = 0; i < Math.min(numClusterToSplit, splitableClusters.size()); i++) {
            splitableClusterIds.add(splitableClusters.get(i).f0);
        }
        LOG.info("toSplitClusterIds {}", JsonConverter.toJson(splitableClusterIds));
        return new HashSet<>(splitableClusterIds);
    }

    private static int getNumLeaf(List<Tuple3<Long, ClusterSummary, IterInfo>> allClusterSummaries) {
        Set<Long> clusterIds = new HashSet<>();
        allClusterSummaries.forEach(t -> clusterIds.add(t.f0));
        int n = 0;
        for (Tuple3<Long, ClusterSummary, IterInfo> t : allClusterSummaries) {
            if (isLeaf(clusterIds, t.f0)) {
                n++;
            }
        }
        return n;
    }

    private static boolean isLeaf(Set<Long> clusterIds, long clusterId) {
        return !clusterIds.contains(leftChildIndex(clusterId)) && !clusterIds.contains(rightChildIndex(clusterId));
    }

    /**
     * Update the assignment of each sample.
     * <p>
     * Note that we keep the updated assignment of each sample in memory, instead of feeding it
     * back through the looped dataset.
     *
     * @param data              Initial assignment of each sample: sampleId, features, clusterId.
     * @param divisibleIndices  Indices of the clusters currently being divided.
     * @param newClusterCenters Centers of the newly created child clusters.
     * @param distance          Distance measure.
     * @param iterInfo          Iteration info.
     * @return Updated assignment of each sample.
     */
    private static DataSet<Tuple3<Long, DenseVector, Long>> updateAssignment(
        DataSet<Tuple3<Long, DenseVector, Long>> data,
        DataSet<Long> divisibleIndices,
        DataSet<Tuple2<Long, DenseVector>> newClusterCenters,
        final ContinuousDistance distance,
        DataSet<Tuple1<IterInfo>> iterInfo) {

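        // Tag each sample with its subtask id, group by that id with an identity partitioner,
        // and sort by sample id. Every sample thus stays on the same subtask in a stable order
        // across supersteps, so UpdateAssignment can keep per-sample state in memory.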
        return data
            .map(new RichMapFunction<Tuple3<Long, DenseVector, Long>,
                Tuple4<Integer, Long, DenseVector, Long>>() {
                private transient int taskId;

                @Override
                public void open(Configuration parameters) {
                    this.taskId = getRuntimeContext().getIndexOfThisSubtask();
                }

                @Override
                public Tuple4<Integer, Long, DenseVector, Long> map(Tuple3<Long, DenseVector, Long> value) {
                    return Tuple4.of(taskId, value.f0, value.f1, value.f2);
                }
            })
            .withForwardedFields("f0->f1;f1->f2;f2->f3")
            .name("append_partition_id")
            .groupBy(0)
            .sortGroup(1, Order.ASCENDING)
            .withPartitioner(new Partitioner<Integer>() {
                @Override
                public int partition(Integer key, int numPartitions) {
                    return key % numPartitions;
                }
            })
            .reduceGroup(new UpdateAssignment(distance))
            .withBroadcastSet(divisibleIndices, DIVISIBLE_INDICES)
            .withBroadcastSet(newClusterCenters, NEW_CLUSTER_CENTERS)
            .withBroadcastSet(iterInfo, ITER_INFO)
            .name("update_assignment");
    }

    static class UpdateAssignment
        extends RichGroupReduceFunction<Tuple4<Integer, Long, DenseVector, Long>, Tuple3<Long, DenseVector, Long>> {
        transient Set<Long> divisibleIndices;
        transient Map<Long, DenseVector> newClusterCenters;

        transient boolean shouldInitState;
        transient boolean shouldUpdateState;
        // sampleId -> clusterId
        transient List<Tuple2<Long, Long>> assignmentInState;

        // In euclidean case, find closer center out of two by checking which
        // side of the middle plane the point lies in.
        transient Map<Long, Tuple2<DenseVector, Double>> middlePlanes;

        ContinuousDistance distance;

        UpdateAssignment(ContinuousDistance distance) {
            this.distance = distance;
        }

        @Override
        public void open(Configuration parameters) {
            List<Long> bcDivisibleIndices = getRuntimeContext().getBroadcastVariable(DIVISIBLE_INDICES);
            divisibleIndices = new HashSet<>(bcDivisibleIndices);
            List<Tuple1<IterInfo>> bcIterInfo = getRuntimeContext().getBroadcastVariable(ITER_INFO);
            shouldUpdateState = bcIterInfo.get(0).f0.atLastInnerIterStep();
            shouldInitState = getIterationRuntimeContext().getSuperstepNumber() == 1;
            List<Tuple2<Long, DenseVector>> bcNewClusterCenters = getRuntimeContext().getBroadcastVariable(
                NEW_CLUSTER_CENTERS);
            newClusterCenters = new HashMap<>(0);
            bcNewClusterCenters.forEach(t -> newClusterCenters.put(t.f0, t.f1));
            if (distance instanceof EuclideanDistance) {
                middlePlanes = new HashMap<>(0);
                divisibleIndices.forEach(parentIndex -> {
                    long lchild = leftChildIndex(parentIndex);
                    long rchild = rightChildIndex(parentIndex);
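                    // m becomes the midpoint (left + right) / 2 and v = right - left is the
                    // normal of the perpendicular bisector plane {x : <x, v> = <m, v>}; a
                    // sample x is closer to the left child iff <x, v> < <m, v>.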
                    DenseVector m = newClusterCenters.get(rchild).plus(newClusterCenters.get(lchild));
                    DenseVector v = newClusterCenters.get(rchild).minus(newClusterCenters.get(lchild));
                    BLAS.scal(0.5, m);
                    double length = BLAS.dot(m, v);
                    middlePlanes.put(parentIndex, Tuple2.of(v, length));
                });
            }
            if (shouldInitState) {
                assignmentInState = new ArrayList<>();
            }
        }

        @Override
        public void reduce(Iterable<Tuple4<Integer, Long, DenseVector, Long>> samples,
                           Collector<Tuple3<Long, DenseVector, Long>> out) {
            int pos = 0;
            for (Tuple4<Integer, Long, DenseVector, Long> sample : samples) {
                long parentClusterId = sample.f3;
                if (shouldInitState) {
                    assignmentInState.add(Tuple2.of(sample.f1, sample.f3));
                } else {
                    if (!sample.f1.equals(assignmentInState.get(pos).f0)) {
                        throw new RuntimeException("Data out of order.");
                    }
                    parentClusterId = assignmentInState.get(pos).f1;
                }
                if (divisibleIndices.contains(parentClusterId)) {
                    long leftChildIdx = leftChildIndex(parentClusterId);
                    long rightChildIdx = rightChildIndex(parentClusterId);
                    long clusterId;
                    if (distance instanceof EuclideanDistance) {
                        Tuple2<DenseVector, Double> plane = middlePlanes.get(parentClusterId);
                        double d = BLAS.dot(sample.f2, plane.f0);
                        clusterId = d < plane.f1 ? leftChildIdx : rightChildIdx;
                    } else {
                        clusterId = getClosestNode(leftChildIdx, newClusterCenters.get(leftChildIdx),
                            rightChildIdx, newClusterCenters.get(rightChildIdx), sample.f2, distance);
                    }
                    out.collect(Tuple3.of(sample.f1, sample.f2, clusterId));

                    if (shouldUpdateState) {
                        assignmentInState.set(pos, Tuple2.of(sample.f1, clusterId));
                    }
                }
                pos++;
            }
        }
    }

    public static long getClosestNode(long leftNode,
                                      DenseVector leftNodeVec,
                                      long rightNode,
                                      DenseVector rightNodeVec,
                                      DenseVector sample,
                                      ContinuousDistance distance) {
        double distanceLeft = distance.calc(sample, leftNodeVec);
        double distanceRight = distance.calc(sample, rightNodeVec);
        return distanceLeft < distanceRight ? leftNode : rightNode;
    }

    /**
     * According to the current sample distribution, get the cluster summary.
     *
     * @param assignment   <ClusterId, Vector> sample pair.
     * @param dim          vectorSize.
     * @param distanceType distance.
     * @return <ClusterId, ClusterSummary> pair.
     */
    private static DataSet<Tuple2<Long, ClusterSummary>>
    summary(DataSet<Tuple2<Long, DenseVector>> assignment, DataSet<Integer> dim, final DistanceType distanceType) {
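        // Two-phase (combiner-style) aggregation: build one ClusterSummaryAggregator per
        // cluster within each partition, then merge the partial aggregators by cluster id.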
        return assignment
            .mapPartition(
                new RichMapPartitionFunction<Tuple2<Long, DenseVector>,
                    Tuple2<Long, ClusterSummaryAggregator>>() {

                    @Override
                    public void mapPartition(Iterable<Tuple2<Long, DenseVector>> values,
                                             Collector<Tuple2<Long, ClusterSummaryAggregator>> out) {
                        Map<Long, ClusterSummaryAggregator> aggregators = new HashMap<>();
                        final int dim = (Integer) (getRuntimeContext().getBroadcastVariable(VECTOR_SIZE).get(0));
                        values.forEach(v -> aggregators
                            .computeIfAbsent(v.f0, id -> new ClusterSummaryAggregator(dim, distanceType))
                            .add(v.f1));
                        aggregators.forEach((k, v) -> out.collect(Tuple2.of(k, v)));
                    }
                })
            .name("local_aggregate_cluster_summary")
            .withBroadcastSet(dim, VECTOR_SIZE)
            .groupBy(0)
            .reduce(new ReduceFunction<Tuple2<Long, ClusterSummaryAggregator>>() {
                @Override
                public Tuple2<Long, ClusterSummaryAggregator> reduce(Tuple2<Long, ClusterSummaryAggregator> value1,
                                                                     Tuple2<Long, ClusterSummaryAggregator> value2) {
                    value1.f1.merge(value2.f1);
                    return value1;
                }
            })
            .name("global_aggregate_cluster_summary")
            .map(
                new MapFunction<Tuple2<Long, ClusterSummaryAggregator>, Tuple2<Long, ClusterSummary>>() {
                    @Override
                    public Tuple2<Long, ClusterSummary> map(
                        Tuple2<Long, ClusterSummaryAggregator> value) {
                        ClusterSummary summary = value.f1.toClusterSummary();
                        return Tuple2.of(value.f0, summary);
                    }
                })
            .withForwardedFields("f0")
            .name("make_cluster_summary");
    }

    private static DataSet<Tuple3<Long, ClusterSummary, IterInfo>>
    updateClusterSummariesAndIterInfo(
        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> oldClusterSummariesWithIterInfo,
        DataSet<Tuple2<Long, ClusterSummary>> newClusterSummaries) {
        return oldClusterSummariesWithIterInfo
            .leftOuterJoin(newClusterSummaries)
            .where(0).equalTo(0)
            .with(
                new RichJoinFunction<Tuple3<Long, ClusterSummary, IterInfo>, Tuple2<Long,
                    ClusterSummary>,
                    Tuple3<Long, ClusterSummary, IterInfo>>() {

                    @Override
                    public Tuple3<Long, ClusterSummary, IterInfo> join(
                        Tuple3<Long, ClusterSummary, IterInfo> oldSummary,
                        Tuple2<Long, ClusterSummary> newSummary) {
                        if (newSummary == null) {
                            Preconditions.checkState(!oldSummary.f2.isNew, "Encountered an empty cluster: %s", oldSummary);
                            oldSummary.f2.updateIterInfo();
                            return oldSummary;
                        } else {
                            IterInfo iterInfo = oldSummary.f2;
                            iterInfo.updateIterInfo();
                            return Tuple3.of(newSummary.f0, newSummary.f1, iterInfo);
                        }
                    }
                })
            .name("update_model");
    }

    /**
     * The bisecting k-means algorithm has nested loops: in the outer loop, cluster centers are
     * split; in the inner loop, the split centers are iteratively refined with k-means steps.
     * However, Flink lacks nested-loop semantics, so we flatten the nested loops in our
     * implementation.
     */
    @Override
    public BisectingKMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);

        // get the input parameter's value
        final DistanceType distanceType = getDistanceType();
        final int k = this.getK();
        final int maxIter = this.getMaxIter();
        final String vectorColName = this.getVectorCol();
        final int minDivisibleClusterSize = this.getMinDivisibleClusterSize();
        ContinuousDistance distance = distanceType.getFastDistance();

        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> vectorsAndStat =
            StatisticsHelper.summaryHelper(in, null, vectorColName);

        DataSet<Integer> dim = vectorsAndStat.f1.map(new MapFunction<BaseVectorSummary, Integer>() {
            @Override
            public Integer map(BaseVectorSummary value) {
                Preconditions.checkArgument(value.count() > 0, "The training dataset is empty!");
                return value.vectorSize();
            }
        });

        // tuple: sampleId, features, assignment
        DataSet<Tuple3<Long, DenseVector, Long>> initialAssignment = DataSetUtils.zipWithUniqueId(vectorsAndStat.f0)
            .map(
                new MapFunction<Tuple2<Long, Vector>, Tuple3<Long, DenseVector, Long>>() {
                    @Override
                    public Tuple3<Long, DenseVector, Long> map(Tuple2<Long, Vector> value) {
                        return Tuple3.of(value.f0, (DenseVector)value.f1, ROOT_INDEX);
                    }
                });

        DataSet<Tuple2<Long, ClusterSummary>> clustersSummaries = summary(
            initialAssignment.project(2, 1), dim, distanceType);

        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> clustersSummariesAndIterInfo
            = clustersSummaries
            .map(new MapFunction<Tuple2<Long, ClusterSummary>, Tuple3<Long, ClusterSummary, IterInfo>>() {
                @Override
                public Tuple3<Long, ClusterSummary, IterInfo> map(Tuple2<Long, ClusterSummary> value) {
                    return Tuple3.of(value.f0, value.f1, new IterInfo(maxIter));
                }
            })
            .withForwardedFields("f0;f1");

        IterativeDataSet<Tuple3<Long, ClusterSummary, IterInfo>> loop
            = clustersSummariesAndIterInfo.iterate(Integer.MAX_VALUE);
        DataSet<Tuple1<IterInfo>> iterInfo = loop.<Tuple1<IterInfo>>project(2).first(1);

        // Get all cluster summaries. Split clusters if at the first step of a bisecting round.
        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> allClusters = getOrSplitClusters(loop, k, minDivisibleClusterSize);

        DataSet<Long> divisibleClusterIndices = getDivisibleClusterIndices(allClusters);
        DataSet<Tuple2<Long, DenseVector>> newClusterCenters = getNewClusterCenters(allClusters);

        DataSet<Tuple3<Long, DenseVector, Long>> newAssignment = updateAssignment(
            initialAssignment, divisibleClusterIndices, newClusterCenters, distance, iterInfo);

        DataSet<Tuple2<Long, ClusterSummary>> newClusterSummaries = summary(
            newAssignment.project(2, 1), dim, distanceType);

        DataSet<Tuple3<Long, ClusterSummary, IterInfo>> updatedClusterSummariesWithIterInfo =
            updateClusterSummariesAndIterInfo(allClusters, newClusterSummaries);

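        // Flink terminates a bulk iteration when the termination-criterion dataset becomes
        // empty, so emit an element only while more bisecting or inner-refinement steps remain.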
        DataSet<Integer> stopCriterion = iterInfo
            .flatMap(new FlatMapFunction<Tuple1<IterInfo>, Integer>() {
                @Override
                public void flatMap(Tuple1<IterInfo> value, Collector<Integer> out) {
                    if (!(value.f0.atLastInnerIterStep() && value.f0.atLastBisectionStep())) {
                        out.collect(0);
                    }
                }
            });

        DataSet<Tuple2<Long, ClusterSummary>> finalClusterSummaries = loop
            .closeWith(updatedClusterSummariesWithIterInfo, stopCriterion)
            .project(0, 1);

        DataSet<Row> modelRows = finalClusterSummaries
            .mapPartition(new SaveModel(distanceType, vectorColName, k))
            .withBroadcastSet(dim, VECTOR_SIZE)
            .setParallelism(1);

        this.setOutput(modelRows, new BisectingKMeansModelDataConverter().getModelSchema());
        return this;
    }

    private static class SaveModel extends RichMapPartitionFunction<Tuple2<Long, ClusterSummary>, Row> {
        private DistanceType distanceType;
        private String vectorColName;
        private int k;

        SaveModel(DistanceType distanceType, String vectorColName, int k) {
            this.distanceType = distanceType;
            this.vectorColName = vectorColName;
            this.k = k;
        }

        @Override
        public void mapPartition(Iterable<Tuple2<Long, ClusterSummary>> values,
                                 Collector<Row> out) {
            Preconditions.checkArgument(getRuntimeContext().getNumberOfParallelSubtasks() <= 1,
                "The parallelism must be 1 when saving the model.");
            final int dim = (Integer)(getRuntimeContext().getBroadcastVariable(VECTOR_SIZE).get(0));
            BisectingKMeansModelData modelData = new BisectingKMeansModelData();
            modelData.summaries = new HashMap<>(0);
            modelData.vectorSize = dim;
            modelData.distanceType = distanceType;
            modelData.vectorColName = vectorColName;
            modelData.k = k;
            values.forEach(t -> modelData.summaries.put(t.f0, t.f1));
            new BisectingKMeansModelDataConverter().save(modelData, out);
        }
    }

    public static class ClusterSummaryAggregator implements Serializable {
        /**
         * Number of samples in the cluster.
         */
        private long count;

        /**
         * Sum of the cluster's sample vectors. For cosine distance, each sample is normalized
         * to unit length before being added.
         */
        private DenseVector sum;

        /**
         * Sum of the squared L2 norms of the cluster's sample vectors.
         */
        private double sumSquared;

        private DistanceType distanceType;

        /**
         * The empty constructor is a 'must' to make it a POJO type.
         */
        ClusterSummaryAggregator() {
        }

        ClusterSummaryAggregator(int dim, DistanceType distanceType) {
            Preconditions.checkArgument(distanceType == DistanceType.EUCLIDEAN || distanceType == DistanceType.COSINE,
                "Unsupported distanceType: %s", distanceType);
            sum = new DenseVector(dim);
            this.distanceType = distanceType;
        }

        public void add(DenseVector v) {
            count++;
            double norm = BLAS.dot(v, v);
            sumSquared += norm;

            if (distanceType == DistanceType.EUCLIDEAN) {
                BLAS.axpy(1., v, sum);
            } else {
                Preconditions.checkArgument(norm > 0, "Cosine distance is not defined for zero vectors.");
                BLAS.axpy(1. / Math.sqrt(norm), v, sum);
            }
        }

        public void merge(ClusterSummaryAggregator other) {
            count += other.count;
            sumSquared += other.sumSquared;
            BLAS.axpy(1.0, other.sum, sum);
        }

        public ClusterSummary toClusterSummary() {
            ClusterSummary summary = new ClusterSummary();
            summary.center = sum.scale(1.0 / count);
            if (distanceType == DistanceType.COSINE) {
                // Renormalize the mean of the unit-length samples back to unit length.
                summary.center.scaleEqual(1.0 / Math.sqrt(BLAS.dot(summary.center, summary.center)));
            }
            summary.cost = calcClusterCost(distanceType, summary.center, sum, count, sumSquared);
            summary.size = count;
            return summary;
        }

        private static double calcClusterCost(DistanceType distanceType,
                                              DenseVector center,
                                              DenseVector sum,
                                              long count,
                                              double sumSquared) {
            if (distanceType == DistanceType.EUCLIDEAN) {
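                // With center c equal to the cluster mean, sum_i ||x_i - c||^2
                //   = sum_i ||x_i||^2 - 2 * <sum_i x_i, c> + n * ||c||^2
                //   = sumSquared - n * ||c||^2, because sum_i x_i = n * c.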
                double centerL2NormSquared = BLAS.dot(center, center);
                double cost = sumSquared - count * centerL2NormSquared;
                return Math.max(cost, 0.);
            } else {
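                // Each added sample was normalized to unit length, so
                // sum_i (1 - cos(x_i, center)) = n - <center, sum> / ||center||.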
                double centerL2Norm = Math.sqrt(BLAS.dot(center, center));
                return Math.max(count - BLAS.dot(center, sum) / centerL2Norm, 0.0);
            }
        }
    }

    public static class IterInfo implements Serializable {
        /**
         * Bisecting (outer) step number.
         */
        public int bisectingStepNo;
        /**
         * Inner iteration step number within the current bisecting step.
         */
        public int innerIterStepNo;
        /**
         * Max number of inner iteration steps per bisecting step.
         */
        public int maxIter;
        public boolean isDividing = false;
        public boolean isNew = false;
        public boolean shouldStopSplit = false;

        // The empty constructor is a 'must' to make it a POJO type.
        public IterInfo() {
        }

        IterInfo(int maxIter) {
            this.maxIter = maxIter;
            this.bisectingStepNo = 0;
            this.innerIterStepNo = 0;
        }

        IterInfo(int maxIter, int bisectingStepNo, int innerIterStepNo, boolean isDividing, boolean isNew,
                 boolean shouldStopSplit) {
            this.bisectingStepNo = bisectingStepNo;
            this.innerIterStepNo = innerIterStepNo;
            this.maxIter = maxIter;
            this.isDividing = isDividing;
            this.isNew = isNew;
            this.shouldStopSplit = shouldStopSplit;
        }

        @Override
        public String toString() {
            return JsonConverter.toJson(this);
        }

        public void updateIterInfo() {
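            // Advance the inner refinement counter; when maxIter inner steps complete, one
            // bisecting (outer) step finishes and the dividing / new flags are reset.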
            innerIterStepNo++;
            if (innerIterStepNo >= maxIter) {
                bisectingStepNo++;
                innerIterStepNo = 0;
                isDividing = false;
                isNew = false;
            }
        }

        public boolean doBisectionInStep() {
            return innerIterStepNo == 0;
        }

        public boolean atLastInnerIterStep() {
            return innerIterStepNo == maxIter - 1;
        }

        public boolean atLastBisectionStep() {
            return shouldStopSplit;
        }
    }
}