# Copyright (c) 2018 Giphy Inc.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from collections import defaultdict
from sklearn.mixture import GaussianMixture


def clusterize(points, n_components=2, covariance_type='tied',
               centers=None, weights=None, output=None, random_state=1000):
    """Group ``points`` into clusters using a Gaussian mixture model.

    Args:
        points: Sized collection of samples handed to
            ``GaussianMixture.fit`` and ``GaussianMixture.predict``.
        n_components: Number of mixture components to fit. Ignored when
            ``centers`` is given; ``len(centers)`` is used instead.
        covariance_type: Forwarded to ``GaussianMixture``.
        centers: Optional initial component means, forwarded as
            ``means_init``.
        weights: Optional initial component weights, forwarded as
            ``weights_init``.
        output: Optional iterable of items to place in the returned clusters
            in place of the points themselves. Items are paired with
            ``points`` positionally; pairing stops at the shorter of the two.
            Defaults to ``points``.
        random_state: Forwarded to ``GaussianMixture``.

    Returns:
        A list of clusters (each a list of items from ``output``), ordered
        from the largest cluster to the smallest. When there are too few
        points to fit the requested number of components, a single cluster
        holding every item of ``output`` is returned.
    """
    if centers is not None:
        n_components = len(centers)

    if output is None:
        output = points

    # A mixture cannot be fitted with fewer samples than components (and
    # never with fewer than two samples), so fall back to one cluster
    # instead of letting GaussianMixture.fit raise ValueError.
    if len(points) < max(2, n_components):
        return [list(output)]

    gmm = GaussianMixture(n_components=n_components,
                          covariance_type=covariance_type,
                          means_init=centers,
                          weights_init=weights,
                          random_state=random_state)
    gmm.fit(points)
    labels = gmm.predict(points)

    # Bucket the output items by the component label of the matching point.
    clusters = defaultdict(list)
    for label, item in zip(labels, output):
        clusters[label].append(item)

    return sorted(clusters.values(), key=len, reverse=True)