Python sklearn.metrics.pairwise_distances_argmin() Examples

The following are 5 code examples of sklearn.metrics.pairwise_distances_argmin(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module sklearn.metrics, or try the search function.
Example #1
Source File: test_birch.py — From Mastering-Elasticsearch-7.0 (MIT License) — 6 votes
def test_birch_predict():
    """Check that Birch.predict predicts the nearest centroid.

    Fits Birch on a shuffled copy of the clustered data, then checks that:
    * ``predict`` on the training data reproduces ``labels_``;
    * the nearest-subcluster-centre assignment, computed independently with
      ``pairwise_distances_argmin``, agrees with ``labels_`` up to a
      relabelling (V-measure of 1.0).
    """
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3,
                                n_samples_per_cluster=10)

    # The data holds n_clusters * n_samples_per_cluster rows. Take the count
    # from X itself rather than hard-coding it, so it tracks the arguments above.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]
    brc = Birch(n_clusters=4, threshold=1.)
    brc.fit(X_shuffle)
    centroids = brc.subcluster_centers_
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))
    nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
Example #2
Source File: test_birch.py — From twitter-stock-recommendation (MIT License) — 6 votes
def test_birch_predict():
    """Check that Birch.predict predicts the nearest centroid.

    Fits Birch on a shuffled copy of the clustered data, then checks that:
    * ``predict`` on the training data reproduces ``labels_``;
    * the nearest-subcluster-centre assignment, computed independently with
      ``pairwise_distances_argmin``, agrees with ``labels_`` up to a
      relabelling (V-measure of 1.0).
    """
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3,
                                n_samples_per_cluster=10)

    # The data holds n_clusters * n_samples_per_cluster rows. Take the count
    # from X itself rather than hard-coding it, so it tracks the arguments above.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]
    brc = Birch(n_clusters=4, threshold=1.)
    brc.fit(X_shuffle)
    centroids = brc.subcluster_centers_
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))
    nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
Example #3
Source File: prediction_strength.py — From theMLbook (MIT License) — 5 votes
def find_clusters(x, n_clusters, current_split):
    """Run a plain k-means (Lloyd's iteration) on one split of the data.

    The data is read from the module-level ``x_split[current_split]``. The
    initial centroids are the first ``n_clusters`` rows after shuffling the
    rows with the module-level ``shuffle``. Each point is then assigned to its
    nearest centroid, and each centroid moves to the mean of its points. This
    repeats until the centroids stop changing.

    NOTE(review): the ``x`` argument is never used; it is kept only so that
    existing callers keep working.

    Args:
        x: Unused.
        n_clusters: Number of clusters; must be between 1 and the number of
            rows in the split.
        current_split: Key selecting the data in ``x_split``.

    Returns:
        ``(centroids, labels)``: one centroid per row, and for every point of
        the split the index of its closest centroid.

    Raises:
        ValueError: If ``n_clusters`` is not in ``1..len(data)``.
    """
    data = np.asarray(x_split[current_split])
    if not 0 < n_clusters <= len(data):
        raise ValueError("n_clusters must be between 1 and the number of samples")

    # Initial centroids: a random sample of n_clusters data points.
    candidates = list(data)
    shuffle(candidates)
    centroids = np.array(candidates[:n_clusters])

    while True:
        # Assign labels based on the closest centroid.
        labels = pairwise_distances_argmin(data, centroids)

        # New centroids are the mean of the points assigned to each cluster.
        # A cluster with no points keeps its previous centroid: the mean of an
        # empty slice is NaN, and since NaN != NaN the convergence check below
        # could then never succeed (infinite loop).
        new_centroids = []
        for i in range(n_clusters):
            members = data[labels == i]
            new_centroids.append(members.mean(axis=0) if len(members) else centroids[i])
        new_centroids = np.array(new_centroids)

        # Converged once no centroid moved.
        if np.array_equal(centroids, new_centroids):
            break
        centroids = new_centroids

    return centroids, labels
Example #4
Source File: _idle.py — From CO2MPAS-TA (European Union Public License 1.1) — 5 votes
def predict(self, X, set_outliers=True):
        """Label each sample with the index of its nearest cluster centre.

        ``self.cluster_centers_`` is expanded with ``[:, None]`` so that each
        centre becomes a single-feature row. NOTE(review): this presumably
        means ``X`` has shape (n_samples, 1); TODO confirm.

        Args:
            X: Samples to label, indexed as a 2-D array.
            set_outliers: When true, a sample whose column-0 value is above
                ``self.max`` or below ``self.min`` is labelled ``-1``.

        Returns:
            Array of nearest-centre indices, with ``-1`` for outliers when
            ``set_outliers`` is true.
        """
        import sklearn.metrics as skm

        centers_as_rows = self.cluster_centers_[:, None]
        labels = skm.pairwise_distances_argmin(X, centers_as_rows)
        if not set_outliers:
            return labels

        # Only the first column decides whether a sample is out of range.
        outside = (X > self.max) | (X < self.min)
        labels[outside[:, 0]] = -1
        return labels
Example #5
Source File: kmeans_pruning.py — From Ridurre-Network-Filter-Pruning-Keras (MIT License) — 4 votes
def run_pruning_for_conv2d_layer(self, pruning_factor: float, layer: layers.Conv2D, layer_weight_mtx) -> List[int]:
        """Select which output channels (filters) of a Conv2D layer to prune.

        The flattened filters are clustered with KMeans into as many clusters
        as channels should be kept. For every cluster centre, the closest
        filter is kept; all other filters are returned as prunable.

        Args:
            pruning_factor: Passed to
                ``self._calculate_number_of_channels_to_keep`` to derive the
                number of clusters (= number of channels to keep).
            layer: The layer being pruned (not used by this method).
            layer_weight_mtx: Kernel laid out as
                ``(height, width, input_channels, output_channels)``.

        Returns:
            Indices (plain ints) of the output channels that can be pruned.

        Raises:
            ValueError: If the number of kept channels cannot be made equal to
                the number of clusters.
        """
        _, _, _, nb_channels = layer_weight_mtx.shape

        # Initialize KMeans with one cluster per channel to keep. ``init`` is
        # passed by keyword because recent scikit-learn releases make every
        # KMeans argument after ``n_clusters`` keyword-only.
        nb_of_clusters, _ = self._calculate_number_of_channels_to_keep(pruning_factor, nb_channels)
        kmeans = cluster.KMeans(nb_of_clusters, init="k-means++")

        # Fit with the flattened weight matrix
        # (height, width, input_channels, output_channels) -> (output_channels, flattened features)
        layer_weight_mtx_reshaped = layer_weight_mtx.transpose(3, 0, 1, 2).reshape(nb_channels, -1)
        # Apply some fuzz to the weights, to avoid duplicates
        self._apply_fuzz(layer_weight_mtx_reshaped)
        kmeans.fit(layer_weight_mtx_reshaped)

        # For every cluster centre, the index of the closest filter; those filters are kept.
        # If a cluster has only a single member, that member is always the closest point to
        # its centre, so it is never pruned.
        closest_point_to_cluster_center_indices = metrics.pairwise_distances_argmin(
            kmeans.cluster_centers_, layer_weight_mtx_reshaped)

        # Compute filter indices which can be pruned (plain Python ints, ascending order).
        channel_indices_to_keep = sorted({int(i) for i in closest_point_to_cluster_center_indices})
        kept = set(channel_indices_to_keep)
        channel_indices_to_prune = [i for i in range(nb_channels) if i not in kept]

        if len(channel_indices_to_keep) > nb_of_clusters:
            print("Number of selected channels for pruning is less than expected")
            diff = len(channel_indices_to_keep) - nb_of_clusters
            print("Randomly adding {0} channels for pruning".format(diff))
            np.random.shuffle(channel_indices_to_keep)
            # Always pop from the end. Popping index ``i`` from a list that
            # shrinks on every iteration skips elements and can raise IndexError.
            for _ in range(diff):
                channel_indices_to_prune.append(channel_indices_to_keep.pop())
        elif len(channel_indices_to_keep) < nb_of_clusters:
            # Happens when several cluster centres share the same closest filter:
            # move arbitrary prunable channels back into the kept set.
            print("Number of selected channels for pruning is greater than expected. Leaving too few channels.")
            diff = nb_of_clusters - len(channel_indices_to_keep)
            print("Discarding {0} pruneable channels".format(diff))
            for _ in range(diff):
                channel_indices_to_keep.append(channel_indices_to_prune.pop())

        if len(channel_indices_to_keep) != nb_of_clusters:
            raise ValueError(
                "Number of clusters {0} is not equal with the number of kept "
                "channels {1}".format(nb_of_clusters, len(channel_indices_to_keep)))

        return channel_indices_to_prune